From b2f94476a41cccd178ebad0154d9438cc4d0d6f9 Mon Sep 17 00:00:00 2001 From: suquark Date: Wed, 20 Feb 2019 11:42:18 +0800 Subject: [PATCH 01/52] Replace native socket operations with asio revert some changes replace event loop with asio update plasma store protocol fix qualifiers update plasma store client protocol Remove all native socket operations. Implement general io support Fix bugs fix all compiling bugs fix bug Fix all tests. Add license header. try to fix cmake try to make asio standalone simplify code add license update url lint lint & fix fix restore entrypoint remove unused unix headers fix Update LICENSE fix rename handle signal move the function to its original place fix doc hide classes stop installing asio headers fix doc reverse changes minor fix tiny fix fix comments minor fix resolve conflicts fix optimize cmake fix update formatter --- LICENSE.txt | 33 +- cpp/apidoc/tutorials/plasma.md | 15 +- cpp/cmake_modules/ThirdpartyToolchain.cmake | 27 + cpp/src/arrow/status.h | 7 + cpp/src/plasma/CMakeLists.txt | 22 +- cpp/src/plasma/client.cc | 156 ++--- cpp/src/plasma/client.h | 15 +- cpp/src/plasma/common.cc | 4 - cpp/src/plasma/dlmalloc.cc | 2 + cpp/src/plasma/events.cc | 107 --- cpp/src/plasma/events.h | 111 ---- cpp/src/plasma/eviction_policy.cc | 1 + cpp/src/plasma/eviction_policy.h | 26 +- cpp/src/plasma/fling.h | 10 +- cpp/src/plasma/io.cc | 241 ------- cpp/src/plasma/io.h | 70 -- cpp/src/plasma/io/basic_connection.cc | 247 +++++++ cpp/src/plasma/io/basic_connection.h | 135 ++++ cpp/src/plasma/io/connection.cc | 327 +++++++++ cpp/src/plasma/io/connection.h | 187 ++++++ cpp/src/plasma/malloc.cc | 1 + cpp/src/plasma/malloc.h | 4 +- cpp/src/plasma/plasma.cc | 45 -- cpp/src/plasma/plasma.h | 58 +- cpp/src/plasma/protocol.cc | 285 +++++--- cpp/src/plasma/protocol.h | 154 +++-- cpp/src/plasma/store.cc | 701 +++++++++----------- cpp/src/plasma/store.h | 172 +++-- cpp/src/plasma/test/client_tests.cc | 15 +- cpp/src/plasma/test/serialization_tests.cc | 
204 +++--- cpp/src/plasma/thirdparty/ae/ae.c | 465 ------------- cpp/src/plasma/thirdparty/ae/ae.h | 123 ---- cpp/src/plasma/thirdparty/ae/ae_epoll.c | 137 ---- cpp/src/plasma/thirdparty/ae/ae_evport.c | 320 --------- cpp/src/plasma/thirdparty/ae/ae_kqueue.c | 138 ---- cpp/src/plasma/thirdparty/ae/ae_select.c | 106 --- cpp/src/plasma/thirdparty/ae/config.h | 54 -- cpp/src/plasma/thirdparty/ae/zmalloc.h | 45 -- dev/release/rat_exclude_files.txt | 8 - python/pyarrow/_plasma.pyx | 16 +- 40 files changed, 1816 insertions(+), 2978 deletions(-) delete mode 100644 cpp/src/plasma/events.cc delete mode 100644 cpp/src/plasma/events.h delete mode 100644 cpp/src/plasma/io.cc delete mode 100644 cpp/src/plasma/io.h create mode 100644 cpp/src/plasma/io/basic_connection.cc create mode 100644 cpp/src/plasma/io/basic_connection.h create mode 100644 cpp/src/plasma/io/connection.cc create mode 100644 cpp/src/plasma/io/connection.h delete mode 100644 cpp/src/plasma/thirdparty/ae/ae.c delete mode 100644 cpp/src/plasma/thirdparty/ae/ae.h delete mode 100644 cpp/src/plasma/thirdparty/ae/ae_epoll.c delete mode 100644 cpp/src/plasma/thirdparty/ae/ae_evport.c delete mode 100644 cpp/src/plasma/thirdparty/ae/ae_kqueue.c delete mode 100644 cpp/src/plasma/thirdparty/ae/ae_select.c delete mode 100644 cpp/src/plasma/thirdparty/ae/config.h delete mode 100644 cpp/src/plasma/thirdparty/ae/zmalloc.h diff --git a/LICENSE.txt b/LICENSE.txt index d66d4ba3818..2d7bb421b4d 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -221,37 +221,6 @@ limitations under the License. -------------------------------------------------------------------------------- -src/plasma/thirdparty/ae: Modified / 3-Clause BSD - -Copyright (c) 2006-2010, Salvatore Sanfilippo -All rights reserved. 
- -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of Redis nor the names of its contributors may be used - to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - src/plasma/thirdparty/dlmalloc.c: CC0 This is a version (aka dlmalloc) of malloc/free/realloc written by @@ -378,7 +347,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- -This project includes code from the Boost project +This project includes code from the Boost project and independent Asio headers Boost Software License - Version 1.0 - August 17th, 2003 diff --git a/cpp/apidoc/tutorials/plasma.md b/cpp/apidoc/tutorials/plasma.md index 40c5a10603e..21ce303f490 100644 --- a/cpp/apidoc/tutorials/plasma.md +++ b/cpp/apidoc/tutorials/plasma.md @@ -384,16 +384,14 @@ sealed in the object store. This may especially be handy when your program is collaborating with other Plasma clients, and needs to know when they make objects available. -First, you can subscribe your current Plasma client to such notifications -by getting a file descriptor: +First, you can subscribe your current Plasma client to such notifications: ```cpp // Start receiving notifications into file_descriptor. -int fd; -ARROW_CHECK_OK(client.Subscribe(&fd)); +ARROW_CHECK_OK(client.Subscribe()); ``` -Once you have the file descriptor, you can have your current Plasma client +Once you have subscribed, you can have your current Plasma client wait to receive the next object notification. Object notifications include information such as Object ID, data size, and metadata size of the next newly available object: @@ -404,7 +402,7 @@ the next newly available object: ObjectID object_id; int64_t data_size; int64_t metadata_size; -ARROW_CHECK_OK(client.GetNotification(fd, &object_id, &data_size, &metadata_size)); +ARROW_CHECK_OK(client.GetNotification(&object_id, &data_size, &metadata_size)); // Get the newly available object. 
ObjectBuffer object_buffer; @@ -423,14 +421,13 @@ int main(int argc, char** argv) { PlasmaClient client; ARROW_CHECK_OK(client.Connect("/tmp/plasma")); - int fd; - ARROW_CHECK_OK(client.Subscribe(&fd)); + ARROW_CHECK_OK(client.Subscribe()); ObjectID object_id; int64_t data_size; int64_t metadata_size; while (true) { - ARROW_CHECK_OK(client.GetNotification(fd, &object_id, &data_size, &metadata_size)); + ARROW_CHECK_OK(client.GetNotification(&object_id, &data_size, &metadata_size)); std::cout << "Received object notification for object_id = " << object_id.hex() << ", with data_size = " << data_size diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index bdb122b2710..f440e3ccd96 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -2419,6 +2419,33 @@ if(ARROW_ORC) endif() endif() +# ---------------------------------------------------------------------- +# Plasma + +if(ARROW_PLASMA) + externalproject_add(asio_ep + URL + https://github.com/chriskohlhoff/asio/archive/asio-1-12-2.zip + CONFIGURE_COMMAND + "" # No autogen since we use asio in header-only way. + BUILD_IN_SOURCE + 1 + BUILD_COMMAND + "" + INSTALL_COMMAND + cmake + -E + copy + asio/include/asio.hpp + ${PROJECT_BINARY_DIR}/src + && + cmake + -E + copy_directory + asio/include/asio + ${PROJECT_BINARY_DIR}/src/asio/) +endif() + # Write out the package configurations. 
configure_file("src/arrow/util/config.h.cmake" "src/arrow/util/config.h") diff --git a/cpp/src/arrow/status.h b/cpp/src/arrow/status.h index 96b018b650d..984d8b541d2 100644 --- a/cpp/src/arrow/status.h +++ b/cpp/src/arrow/status.h @@ -91,6 +91,7 @@ enum class StatusCode : char { SerializationError = 11, PythonError = 12, RError = 13, + ProtocolError = 14, PlasmaObjectExists = 20, PlasmaObjectNonexistent = 21, PlasmaStoreFull = 22, @@ -218,6 +219,12 @@ class ARROW_EXPORT Status { return Status(StatusCode::RError, util::StringBuilder(std::forward(args)...)); } + template + static Status ProtocolError(Args&&... args) { + return Status(StatusCode::ProtocolError, + util::StringBuilder(std::forward(args)...)); + } + template static Status PlasmaObjectExists(Args&&... args) { return Status(StatusCode::PlasmaObjectExists, diff --git a/cpp/src/plasma/CMakeLists.txt b/cpp/src/plasma/CMakeLists.txt index 729fba7e944..baba62710c9 100644 --- a/cpp/src/plasma/CMakeLists.txt +++ b/cpp/src/plasma/CMakeLists.txt @@ -61,15 +61,19 @@ add_custom_command( set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") -set(PLASMA_SRCS client.cc common.cc fling.cc io.cc malloc.cc plasma.cc protocol.cc) +# Set compiling options for asio headers. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion -Wno-documentation") + +set(PLASMA_IO_SRCS fling.cc io/basic_connection.cc io/connection.cc protocol.cc) + +set(PLASMA_SRCS client.cc common.cc malloc.cc plasma.cc protocol.cc ${PLASMA_IO_SRCS}) set(PLASMA_STORE_SRCS dlmalloc.cc - events.cc eviction_policy.cc plasma_allocator.cc - store.cc - thirdparty/ae/ae.c) + ${PLASMA_IO_SRCS} + store.cc) set(PLASMA_LINK_LIBS arrow_shared) set(PLASMA_STATIC_LINK_LIBS arrow_static) @@ -87,6 +91,7 @@ add_arrow_lib(plasma PLASMA_LIBRARIES DEPENDENCIES gen_plasma_fbs + asio_ep SHARED_LINK_LIBS ${PLASMA_LINK_LIBS} STATIC_LINK_LIBS @@ -120,13 +125,16 @@ list(APPEND PLASMA_EXTERNAL_STORE_SOURCES "external_store.cc" "hash_table_store. 
# We use static libraries for the plasma_store_server executable so that it can # be copied around and used in different locations. add_executable(plasma_store_server ${PLASMA_EXTERNAL_STORE_SOURCES} ${PLASMA_STORE_SRCS}) + if(ARROW_BUILD_STATIC) target_link_libraries(plasma_store_server plasma_static ${PLASMA_STATIC_LINK_LIBS}) else() # Fallback to shared libs in the case that static libraries are not build. target_link_libraries(plasma_store_server plasma_shared ${PLASMA_LINK_LIBS}) endif() + add_dependencies(plasma plasma_store_server) +add_dependencies(plasma_store_server asio_ep) if(ARROW_RPATH_ORIGIN) if(APPLE) @@ -153,11 +161,7 @@ elseif(APPLE) endif() endif() -install(FILES common.h - compat.h - client.h - events.h - test-util.h +install(FILES common.h compat.h client.h test-util.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/plasma") # Plasma store diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index e88c5cadff0..f4e713afb9d 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -23,34 +23,22 @@ #include #endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include // PROT_READ, PROT_WRITE, MAP_SHARED, MAP_FAILED #include #include #include #include #include +#include #include #include "arrow/buffer.h" +#include "arrow/util/logging.h" #include "arrow/util/thread-pool.h" #include "plasma/common.h" -#include "plasma/fling.h" -#include "plasma/io.h" #include "plasma/malloc.h" -#include "plasma/plasma.h" #include "plasma/protocol.h" #ifdef PLASMA_CUDA @@ -68,12 +56,11 @@ using arrow::cuda::CudaDeviceManager; #define XXH64_DEFAULT_SEED 0 -namespace fb = plasma::flatbuf; - namespace plasma { -using fb::MessageType; -using fb::PlasmaError; +using flatbuf::MessageType; +using flatbuf::PlasmaError; +using io::ServerConnection; using arrow::MutableBuffer; @@ -196,6 +183,11 @@ class ClientMmapTableEntry { ARROW_DISALLOW_COPY_AND_ASSIGN(ClientMmapTableEntry); }; 
+Status PlasmaReceive(const std::shared_ptr& client, + MessageType message_type, std::vector* buffer) { + return client->ReadMessage(static_cast(message_type), buffer); +} + class PlasmaClient::Impl : public std::enable_shared_from_this { public: Impl(); @@ -203,9 +195,7 @@ class PlasmaClient::Impl : public std::enable_shared_from_this* data, int device_num = 0); @@ -235,13 +225,16 @@ class PlasmaClient::Impl : public std::enable_shared_from_thisGetNativeHandle(); + } Status Disconnect(); @@ -286,8 +279,13 @@ class PlasmaClient::Impl : public std::enable_shared_from_this store_conn_; + std::shared_ptr notification_conn_; + /// The name of the socket we are connecting to. + std::string store_socket_name_; /// Table of dlmalloc buffer files that have been memory mapped so far. This /// is a hash table mapping a file descriptor to a struct containing the /// address of the corresponding memory-mapped file. @@ -348,8 +346,9 @@ bool PlasmaClient::Impl::IsInUse(const ObjectID& object_id) { int PlasmaClient::Impl::GetStoreFd(int store_fd) { auto entry = mmap_table_.find(store_fd); if (entry == mmap_table_.end()) { - int fd = recv_fd(store_conn_); - ARROW_CHECK(fd >= 0) << "recv not successful"; + int fd; + auto status = store_conn_->RecvFd(&fd); + ARROW_CHECK(status.ok() && fd >= 0) << "recv not successful"; return fd; } else { return entry->second->fd(); @@ -794,8 +793,7 @@ Status PlasmaClient::Impl::Abort(const ObjectID& object_id) { std::vector buffer; ObjectID id; - MessageType type; - RETURN_NOT_OK(ReadMessage(store_conn_, &type, &buffer)); + RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaAbortReply, &buffer)); return ReadAbortReply(buffer.data(), buffer.size(), &id); } @@ -827,8 +825,7 @@ Status PlasmaClient::Impl::Evict(int64_t num_bytes, int64_t& num_bytes_evicted) RETURN_NOT_OK(SendEvictRequest(store_conn_, num_bytes)); // Wait for a response with the number of bytes actually evicted. 
std::vector buffer; - MessageType type; - RETURN_NOT_OK(ReadMessage(store_conn_, &type, &buffer)); + RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaEvictReply, &buffer)); return ReadEvictReply(buffer.data(), buffer.size(), num_bytes_evicted); } @@ -847,30 +844,22 @@ Status PlasmaClient::Impl::Hash(const ObjectID& object_id, uint8_t* digest) { return Status::OK(); } -Status PlasmaClient::Impl::Subscribe(int* fd) { - int sock[2]; - // Create a non-blocking socket pair. This will only be used to send - // notifications from the Plasma store to the client. - socketpair(AF_UNIX, SOCK_STREAM, 0, sock); - // Make the socket non-blocking. - int flags = fcntl(sock[1], F_GETFL, 0); - ARROW_CHECK(fcntl(sock[1], F_SETFL, flags | O_NONBLOCK) == 0); +Status PlasmaClient::Impl::Subscribe() { + if (store_socket_name_.empty()) { + ARROW_LOG(FATAL) << "Please connect to the store before subscribing messages."; + } + auto stream = io::CreateLocalStream(io_context_, store_socket_name_); + auto conn = ServerConnection::Create(std::move(stream)); + notification_conn_ = std::move(conn); // Tell the Plasma store about the subscription. - RETURN_NOT_OK(SendSubscribeRequest(store_conn_)); - // Send the file descriptor that the Plasma store should use to push - // notifications about sealed objects to this client. - ARROW_CHECK(send_fd(store_conn_, sock[1]) >= 0); - close(sock[1]); - // Return the file descriptor that the client should use to read notifications - // about sealed objects. 
- *fd = sock[0]; - return Status::OK(); + return SendSubscribeRequest(notification_conn_); } +// TODO(suquark): Move it to protocol.cc Status PlasmaClient::Impl::DecodeNotification(const uint8_t* buffer, ObjectID* object_id, int64_t* data_size, int64_t* metadata_size) { - auto object_info = flatbuffers::GetRoot(buffer); + auto object_info = flatbuffers::GetRoot(buffer); ARROW_CHECK(object_info->object_id()->size() == sizeof(ObjectID)); memcpy(object_id, object_info->object_id()->data(), sizeof(ObjectID)); if (object_info->is_deletion()) { @@ -883,26 +872,25 @@ Status PlasmaClient::Impl::DecodeNotification(const uint8_t* buffer, ObjectID* o return Status::OK(); } -Status PlasmaClient::Impl::GetNotification(int fd, ObjectID* object_id, - int64_t* data_size, int64_t* metadata_size) { - auto notification = ReadMessageAsync(fd); - if (notification == NULL) { +Status PlasmaClient::Impl::GetNotification(ObjectID* object_id, int64_t* data_size, + int64_t* metadata_size) { + std::unique_ptr notification; + if (!notification_conn_) { + ARROW_LOG(ERROR) << "Get notification without subscription."; + return Status::ExecutionError("Get notification without subscription."); + } + auto status = notification_conn_->ReadNotificationMessage(notification); + if (!status.ok()) { return Status::IOError("Failed to read object notification from Plasma socket"); } return DecodeNotification(notification.get(), object_id, data_size, metadata_size); } -Status PlasmaClient::Impl::Connect(const std::string& store_socket_name, - const std::string& manager_socket_name, - int release_delay, int num_retries) { - RETURN_NOT_OK(ConnectIpcSocketRetry(store_socket_name, num_retries, -1, &store_conn_)); - if (manager_socket_name != "") { - return Status::NotImplemented("plasma manager is no longer supported"); - } - if (release_delay != 0) { - ARROW_LOG(WARNING) << "The release_delay parameter in PlasmaClient::Connect " - << "is deprecated"; - } +Status PlasmaClient::Impl::Connect(const std::string& 
store_socket_name) { + store_socket_name_ = store_socket_name; + auto stream = io::CreateLocalStream(io_context_, store_socket_name_); + auto conn = ServerConnection::Create(std::move(stream)); + store_conn_ = std::move(conn); // Send a ConnectRequest to the store to get its memory capacity. RETURN_NOT_OK(SendConnectRequest(store_conn_)); std::vector buffer; @@ -918,9 +906,14 @@ Status PlasmaClient::Impl::Disconnect() { // Close the connections to Plasma. The Plasma store will release the objects // that were in use by us when handling the SIGPIPE. - close(store_conn_); - store_conn_ = -1; - return Status::OK(); + if (notification_conn_) { + auto status = notification_conn_->Disconnect(); + if (!status.ok()) { + ARROW_LOG(ERROR) << "Failed to disconnect notification client " + << "(" << status << ")"; + } + } + return store_conn_->Disconnect(); } // ---------------------------------------------------------------------- @@ -933,8 +926,19 @@ PlasmaClient::~PlasmaClient() {} Status PlasmaClient::Connect(const std::string& store_socket_name, const std::string& manager_socket_name, int release_delay, int num_retries) { - return impl_->Connect(store_socket_name, manager_socket_name, release_delay, - num_retries); + // Keep "manager_socket_name" & "release_delay" for compatibility. 
+ if (manager_socket_name != "") { + return Status::NotImplemented("plasma manager is no longer supported"); + } + if (release_delay != 0) { + ARROW_LOG(WARNING) << "The release_delay parameter in PlasmaClient::Connect " + << "is deprecated"; + } + if (num_retries != -1) { + ARROW_LOG(WARNING) << "The num_retries parameter in PlasmaClient::Connect " + << "is deprecated"; + } + return impl_->Connect(store_socket_name); } Status PlasmaClient::Create(const ObjectID& object_id, int64_t data_size, @@ -988,11 +992,11 @@ Status PlasmaClient::Hash(const ObjectID& object_id, uint8_t* digest) { return impl_->Hash(object_id, digest); } -Status PlasmaClient::Subscribe(int* fd) { return impl_->Subscribe(fd); } +Status PlasmaClient::Subscribe() { return impl_->Subscribe(); } -Status PlasmaClient::GetNotification(int fd, ObjectID* object_id, int64_t* data_size, +Status PlasmaClient::GetNotification(ObjectID* object_id, int64_t* data_size, int64_t* metadata_size) { - return impl_->GetNotification(fd, object_id, data_size, metadata_size); + return impl_->GetNotification(object_id, data_size, metadata_size); } Status PlasmaClient::DecodeNotification(const uint8_t* buffer, ObjectID* object_id, @@ -1000,6 +1004,10 @@ Status PlasmaClient::DecodeNotification(const uint8_t* buffer, ObjectID* object_ return impl_->DecodeNotification(buffer, object_id, data_size, metadata_size); } +int PlasmaClient::GetNativeNotificationHandle() { + return impl_->GetNativeNotificationHandle(); +} + Status PlasmaClient::Disconnect() { return impl_->Disconnect(); } bool PlasmaClient::IsInUse(const ObjectID& object_id) { diff --git a/cpp/src/plasma/client.h b/cpp/src/plasma/client.h index facfd37ca78..f35ada95b6c 100644 --- a/cpp/src/plasma/client.h +++ b/cpp/src/plasma/client.h @@ -59,7 +59,8 @@ class ARROW_EXPORT PlasmaClient { /// Note that plasma manager is no longer supported, this function /// will return failure if this is not "". /// \param release_delay Deprecated (not used). 
- /// \param num_retries number of attempts to connect to IPC socket, default 50 + /// \param num_retries number of attempts to connect to IPC socket, + /// default 50. Deprecated (not used). /// \return The return status. Status Connect(const std::string& store_socket_name, const std::string& manager_socket_name = "", int release_delay = 0, @@ -225,21 +226,19 @@ class ARROW_EXPORT PlasmaClient { /// Whenever an object is sealed, a message will be written to the client /// socket that is returned by this method. /// - /// \param fd Out parameter for the file descriptor the client should use to - /// read notifications - /// from the object store about sealed objects. /// \return The return status. - Status Subscribe(int* fd); + Status Subscribe(); + + /// Return the native handle of the notification client. + int GetNativeNotificationHandle(); /// Receive next object notification for this client if Subscribe has been called. /// - /// \param fd The file descriptor we are reading the notification from. /// \param object_id Out parameter, the object_id of the object that was sealed. /// \param data_size Out parameter, the data size of the object that was sealed. /// \param metadata_size Out parameter, the metadata size of the object that was sealed. /// \return The return status. 
- Status GetNotification(int fd, ObjectID* object_id, int64_t* data_size, - int64_t* metadata_size); + Status GetNotification(ObjectID* object_id, int64_t* data_size, int64_t* metadata_size); Status DecodeNotification(const uint8_t* buffer, ObjectID* object_id, int64_t* data_size, int64_t* metadata_size); diff --git a/cpp/src/plasma/common.cc b/cpp/src/plasma/common.cc index 490aa158b33..de4dc358cef 100644 --- a/cpp/src/plasma/common.cc +++ b/cpp/src/plasma/common.cc @@ -19,10 +19,6 @@ #include -#include "plasma/plasma_generated.h" - -namespace fb = plasma::flatbuf; - namespace plasma { using arrow::Status; diff --git a/cpp/src/plasma/dlmalloc.cc b/cpp/src/plasma/dlmalloc.cc index 463e967e036..6acbf46e705 100644 --- a/cpp/src/plasma/dlmalloc.cc +++ b/cpp/src/plasma/dlmalloc.cc @@ -29,6 +29,8 @@ #include #include +#include "arrow/util/logging.h" + #include "plasma/common.h" #include "plasma/plasma.h" diff --git a/cpp/src/plasma/events.cc b/cpp/src/plasma/events.cc deleted file mode 100644 index 28ff1267545..00000000000 --- a/cpp/src/plasma/events.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#include "plasma/events.h" - -#include - -#include - -extern "C" { -#include "plasma/thirdparty/ae/ae.h" -} - -namespace plasma { - -// Verify that the constants defined in events.h are defined correctly. -static_assert(kEventLoopTimerDone == AE_NOMORE, "constant defined incorrectly"); -static_assert(kEventLoopOk == AE_OK, "constant defined incorrectly"); -static_assert(kEventLoopRead == AE_READABLE, "constant defined incorrectly"); -static_assert(kEventLoopWrite == AE_WRITABLE, "constant defined incorrectly"); - -void EventLoop::FileEventCallback(aeEventLoop* loop, int fd, void* context, int events) { - FileCallback* callback = reinterpret_cast(context); - (*callback)(events); -} - -int EventLoop::TimerEventCallback(aeEventLoop* loop, TimerID timer_id, void* context) { - TimerCallback* callback = reinterpret_cast(context); - return (*callback)(timer_id); -} - -constexpr int kInitialEventLoopSize = 1024; - -EventLoop::EventLoop() { loop_ = aeCreateEventLoop(kInitialEventLoopSize); } - -bool EventLoop::AddFileEvent(int fd, int events, const FileCallback& callback) { - if (file_callbacks_.find(fd) != file_callbacks_.end()) { - return false; - } - auto data = std::unique_ptr(new FileCallback(callback)); - void* context = reinterpret_cast(data.get()); - // Try to add the file descriptor. - int err = aeCreateFileEvent(loop_, fd, events, EventLoop::FileEventCallback, context); - // If it cannot be added, increase the size of the event loop. - if (err == AE_ERR && errno == ERANGE) { - err = aeResizeSetSize(loop_, 3 * aeGetSetSize(loop_) / 2); - if (err != AE_OK) { - return false; - } - err = aeCreateFileEvent(loop_, fd, events, EventLoop::FileEventCallback, context); - } - // In any case, test if there were errors. 
- if (err == AE_OK) { - file_callbacks_.emplace(fd, std::move(data)); - return true; - } - return false; -} - -void EventLoop::RemoveFileEvent(int fd) { - aeDeleteFileEvent(loop_, fd, AE_READABLE | AE_WRITABLE); - file_callbacks_.erase(fd); -} - -void EventLoop::Start() { aeMain(loop_); } - -void EventLoop::Stop() { aeStop(loop_); } - -void EventLoop::Shutdown() { - if (loop_ != nullptr) { - aeDeleteEventLoop(loop_); - loop_ = nullptr; - } -} - -EventLoop::~EventLoop() { Shutdown(); } - -int64_t EventLoop::AddTimer(int64_t timeout, const TimerCallback& callback) { - auto data = std::unique_ptr(new TimerCallback(callback)); - void* context = reinterpret_cast(data.get()); - int64_t timer_id = - aeCreateTimeEvent(loop_, timeout, EventLoop::TimerEventCallback, context, NULL); - timer_callbacks_.emplace(timer_id, std::move(data)); - return timer_id; -} - -int EventLoop::RemoveTimer(int64_t timer_id) { - int err = aeDeleteTimeEvent(loop_, timer_id); - timer_callbacks_.erase(timer_id); - return err; -} - -} // namespace plasma diff --git a/cpp/src/plasma/events.h b/cpp/src/plasma/events.h deleted file mode 100644 index 765be9c01fb..00000000000 --- a/cpp/src/plasma/events.h +++ /dev/null @@ -1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -#ifndef PLASMA_EVENTS -#define PLASMA_EVENTS - -#include -#include -#include - -struct aeEventLoop; - -namespace plasma { - -// The constants below are defined using hardcoded values taken from ae.h so -// that ae.h does not need to be included in this file. - -/// Constant specifying that the timer is done and it will be removed. -constexpr int kEventLoopTimerDone = -1; // AE_NOMORE - -/// A successful status. -constexpr int kEventLoopOk = 0; // AE_OK - -/// Read event on the file descriptor. -constexpr int kEventLoopRead = 1; // AE_READABLE - -/// Write event on the file descriptor. -constexpr int kEventLoopWrite = 2; // AE_WRITABLE - -typedef long long TimerID; // NOLINT - -class EventLoop { - public: - // Signature of the handler that will be called when there is a new event - // on the file descriptor that this handler has been registered for. - // - // The arguments are the event flags (read or write). - using FileCallback = std::function; - - // This handler will be called when a timer times out. The timer id is - // passed as an argument. The return is the number of milliseconds the timer - // shall be reset to or kEventLoopTimerDone if the timer shall not be - // triggered again. - using TimerCallback = std::function; - - EventLoop(); - - ~EventLoop(); - - /// Add a new file event handler to the event loop. - /// - /// \param fd The file descriptor we are listening to. - /// \param events The flags for events we are listening to (read or write). - /// \param callback The callback that will be called when the event happens. - /// \return Returns true if the event handler was added successfully. - bool AddFileEvent(int fd, int events, const FileCallback& callback); - - /// Remove a file event handler from the event loop. - /// - /// \param fd The file descriptor of the event handler. 
- void RemoveFileEvent(int fd); - - /// Register a handler that will be called after a time slice of - /// "timeout" milliseconds. - /// - /// \param timeout The timeout in milliseconds. - /// \param callback The callback for the timeout. - /// \return The ID of the newly created timer. - int64_t AddTimer(int64_t timeout, const TimerCallback& callback); - - /// Remove a timer handler from the event loop. - /// - /// \param timer_id The ID of the timer that is to be removed. - /// \return The ae.c error code. TODO(pcm): needs to be standardized - int RemoveTimer(int64_t timer_id); - - /// \brief Run the event loop. - void Start(); - - /// \brief Stop the event loop - void Stop(); - - void Shutdown(); - - private: - static void FileEventCallback(aeEventLoop* loop, int fd, void* context, int events); - - static int TimerEventCallback(aeEventLoop* loop, TimerID timer_id, void* context); - - aeEventLoop* loop_; - std::unordered_map> file_callbacks_; - std::unordered_map> timer_callbacks_; -}; - -} // namespace plasma - -#endif // PLASMA_EVENTS diff --git a/cpp/src/plasma/eviction_policy.cc b/cpp/src/plasma/eviction_policy.cc index da5df5a36dd..c4fe9c9fea7 100644 --- a/cpp/src/plasma/eviction_policy.cc +++ b/cpp/src/plasma/eviction_policy.cc @@ -16,6 +16,7 @@ // under the License. #include "plasma/eviction_policy.h" +#include "arrow/util/logging.h" #include "plasma/plasma_allocator.h" #include diff --git a/cpp/src/plasma/eviction_policy.h b/cpp/src/plasma/eviction_policy.h index 68342ae102f..2f04be7898d 100644 --- a/cpp/src/plasma/eviction_policy.h +++ b/cpp/src/plasma/eviction_policy.h @@ -60,7 +60,7 @@ class EvictionPolicy { public: /// Construct an eviction policy. /// - /// @param store_info Information about the Plasma store that is exposed + /// \param store_info Information about the Plasma store that is exposed /// to the eviction policy. 
explicit EvictionPolicy(PlasmaStoreInfo* store_info); @@ -69,7 +69,7 @@ class EvictionPolicy { /// store calls begin_object_access, we can remove the object from the LRU /// cache. /// - /// @param object_id The object ID of the object that was created. + /// \param object_id The object ID of the object that was created. void ObjectCreated(const ObjectID& object_id); /// This method will be called when the Plasma store needs more space, perhaps @@ -77,11 +77,11 @@ class EvictionPolicy { /// policy will assume that the objects chosen to be evicted will in fact be /// evicted from the Plasma store by the caller. /// - /// @param size The size in bytes of the new object, including both data and + /// \param size The size in bytes of the new object, including both data and /// metadata. - /// @param objects_to_evict The object IDs that were chosen for eviction will + /// \param objects_to_evict The object IDs that were chosen for eviction will /// be stored into this vector. - /// @return True if enough space can be freed and false otherwise. + /// \return True if enough space can be freed and false otherwise. bool RequireSpace(int64_t size, std::vector* objects_to_evict); /// This method will be called whenever an unused object in the Plasma store @@ -89,8 +89,8 @@ class EvictionPolicy { /// assume that the objects chosen to be evicted will in fact be evicted from /// the Plasma store by the caller. /// - /// @param object_id The ID of the object that is now being used. - /// @param objects_to_evict The object IDs that were chosen for eviction will + /// \param object_id The ID of the object that is now being used. + /// \param objects_to_evict The object IDs that were chosen for eviction will /// be stored into this vector. 
void BeginObjectAccess(const ObjectID& object_id, std::vector* objects_to_evict); @@ -100,8 +100,8 @@ class EvictionPolicy { /// eviction policy will assume that the objects chosen to be evicted will in /// fact be evicted from the Plasma store by the caller. /// - /// @param object_id The ID of the object that is no longer being used. - /// @param objects_to_evict The object IDs that were chosen for eviction will + /// \param object_id The ID of the object that is no longer being used. + /// \param objects_to_evict The object IDs that were chosen for eviction will /// be stored into this vector. void EndObjectAccess(const ObjectID& object_id, std::vector* objects_to_evict); @@ -113,16 +113,16 @@ class EvictionPolicy { /// @note This method is not part of the API. It is exposed in the header file /// only for testing. /// - /// @param num_bytes_required The number of bytes of space to try to free up. - /// @param objects_to_evict The object IDs that were chosen for eviction will + /// \param num_bytes_required The number of bytes of space to try to free up. + /// \param objects_to_evict The object IDs that were chosen for eviction will /// be stored into this vector. - /// @return The total number of bytes of space chosen to be evicted. + /// \return The total number of bytes of space chosen to be evicted. int64_t ChooseObjectsToEvict(int64_t num_bytes_required, std::vector* objects_to_evict); /// This method will be called when an object is going to be removed /// - /// @param object_id The ID of the object that is now being used. + /// \param object_id The ID of the object that is now being used. void RemoveObject(const ObjectID& object_id); private: diff --git a/cpp/src/plasma/fling.h b/cpp/src/plasma/fling.h index 78ac9d17f26..f05803db690 100644 --- a/cpp/src/plasma/fling.h +++ b/cpp/src/plasma/fling.h @@ -40,13 +40,13 @@ void init_msg(struct msghdr* msg, struct iovec* iov, char* buf, size_t buf_len); // Send a file descriptor over a unix domain socket. 
// -// @param conn Unix domain socket to send the file descriptor over. -// @param fd File descriptor to send over. -// @return Status code which is < 0 on failure. +// \param conn Unix domain socket to send the file descriptor over. +// \param fd File descriptor to send over. +// \return Status code which is < 0 on failure. int send_fd(int conn, int fd); // Receive a file descriptor over a unix domain socket. // -// @param conn Unix domain socket to receive the file descriptor from. -// @return File descriptor or a value < 0 on failure. +// \param conn Unix domain socket to receive the file descriptor from. +// \return File descriptor or a value < 0 on failure. int recv_fd(int conn); diff --git a/cpp/src/plasma/io.cc b/cpp/src/plasma/io.cc deleted file mode 100644 index ba5f2551919..00000000000 --- a/cpp/src/plasma/io.cc +++ /dev/null @@ -1,241 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "plasma/io.h" - -#include -#include -#include - -#include "arrow/status.h" -#include "arrow/util/logging.h" - -#include "plasma/common.h" -#include "plasma/plasma_generated.h" - -using arrow::Status; - -/// Number of times we try connecting to a socket. 
-constexpr int64_t kNumConnectAttempts = 20; -/// Time to wait between connection attempts to a socket. -constexpr int64_t kConnectTimeoutMs = 400; - -namespace plasma { - -using flatbuf::MessageType; - -Status WriteBytes(int fd, uint8_t* cursor, size_t length) { - ssize_t nbytes = 0; - size_t bytesleft = length; - size_t offset = 0; - while (bytesleft > 0) { - // While we haven't written the whole message, write to the file descriptor, - // advance the cursor, and decrease the amount left to write. - nbytes = write(fd, cursor + offset, bytesleft); - if (nbytes < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - continue; - } - return Status::IOError(strerror(errno)); - } else if (nbytes == 0) { - return Status::IOError("Encountered unexpected EOF"); - } - ARROW_CHECK(nbytes > 0); - bytesleft -= nbytes; - offset += nbytes; - } - - return Status::OK(); -} - -Status WriteMessage(int fd, MessageType type, int64_t length, uint8_t* bytes) { - int64_t version = kPlasmaProtocolVersion; - RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast(&version), sizeof(version))); - RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast(&type), sizeof(type))); - RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast(&length), sizeof(length))); - return WriteBytes(fd, bytes, length * sizeof(char)); -} - -Status ReadBytes(int fd, uint8_t* cursor, size_t length) { - ssize_t nbytes = 0; - // Termination condition: EOF or read 'length' bytes total. 
- size_t bytesleft = length; - size_t offset = 0; - while (bytesleft > 0) { - nbytes = read(fd, cursor + offset, bytesleft); - if (nbytes < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - continue; - } - return Status::IOError(strerror(errno)); - } else if (0 == nbytes) { - return Status::IOError("Encountered unexpected EOF"); - } - ARROW_CHECK(nbytes > 0); - bytesleft -= nbytes; - offset += nbytes; - } - - return Status::OK(); -} - -Status ReadMessage(int fd, MessageType* type, std::vector* buffer) { - int64_t version; - RETURN_NOT_OK_ELSE(ReadBytes(fd, reinterpret_cast(&version), sizeof(version)), - *type = MessageType::PlasmaDisconnectClient); - ARROW_CHECK(version == kPlasmaProtocolVersion) << "version = " << version; - RETURN_NOT_OK_ELSE(ReadBytes(fd, reinterpret_cast(type), sizeof(*type)), - *type = MessageType::PlasmaDisconnectClient); - int64_t length_temp; - RETURN_NOT_OK_ELSE( - ReadBytes(fd, reinterpret_cast(&length_temp), sizeof(length_temp)), - *type = MessageType::PlasmaDisconnectClient); - // The length must be read as an int64_t, but it should be used as a size_t. - size_t length = static_cast(length_temp); - if (length > buffer->size()) { - buffer->resize(length); - } - RETURN_NOT_OK_ELSE(ReadBytes(fd, buffer->data(), length), - *type = MessageType::PlasmaDisconnectClient); - return Status::OK(); -} - -int BindIpcSock(const std::string& pathname, bool shall_listen) { - struct sockaddr_un socket_address; - int socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socket_fd < 0) { - ARROW_LOG(ERROR) << "socket() failed for pathname " << pathname; - return -1; - } - // Tell the system to allow the port to be reused. 
- int on = 1; - if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&on), - sizeof(on)) < 0) { - ARROW_LOG(ERROR) << "setsockopt failed for pathname " << pathname; - close(socket_fd); - return -1; - } - - unlink(pathname.c_str()); - memset(&socket_address, 0, sizeof(socket_address)); - socket_address.sun_family = AF_UNIX; - if (pathname.size() + 1 > sizeof(socket_address.sun_path)) { - ARROW_LOG(ERROR) << "Socket pathname is too long."; - close(socket_fd); - return -1; - } - strncpy(socket_address.sun_path, pathname.c_str(), pathname.size() + 1); - - if (bind(socket_fd, reinterpret_cast(&socket_address), - sizeof(socket_address)) != 0) { - ARROW_LOG(ERROR) << "Bind failed for pathname " << pathname; - close(socket_fd); - return -1; - } - if (shall_listen && listen(socket_fd, 128) == -1) { - ARROW_LOG(ERROR) << "Could not listen to socket " << pathname; - close(socket_fd); - return -1; - } - return socket_fd; -} - -Status ConnectIpcSocketRetry(const std::string& pathname, int num_retries, - int64_t timeout, int* fd) { - // Pick the default values if the user did not specify. - if (num_retries < 0) { - num_retries = kNumConnectAttempts; - } - if (timeout < 0) { - timeout = kConnectTimeoutMs; - } - *fd = ConnectIpcSock(pathname); - while (*fd < 0 && num_retries > 0) { - ARROW_LOG(ERROR) << "Connection to IPC socket failed for pathname " << pathname - << ", retrying " << num_retries << " more times"; - // Sleep for timeout milliseconds. - usleep(static_cast(timeout * 1000)); - *fd = ConnectIpcSock(pathname); - --num_retries; - } - - // If we could not connect to the socket, exit. 
- if (*fd == -1) { - return Status::IOError("Could not connect to socket ", pathname); - } - - return Status::OK(); -} - -int ConnectIpcSock(const std::string& pathname) { - struct sockaddr_un socket_address; - int socket_fd; - - socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socket_fd < 0) { - ARROW_LOG(ERROR) << "socket() failed for pathname " << pathname; - return -1; - } - - memset(&socket_address, 0, sizeof(socket_address)); - socket_address.sun_family = AF_UNIX; - if (pathname.size() + 1 > sizeof(socket_address.sun_path)) { - ARROW_LOG(ERROR) << "Socket pathname is too long."; - close(socket_fd); - return -1; - } - strncpy(socket_address.sun_path, pathname.c_str(), pathname.size() + 1); - - if (connect(socket_fd, reinterpret_cast(&socket_address), - sizeof(socket_address)) != 0) { - close(socket_fd); - return -1; - } - - return socket_fd; -} - -int AcceptClient(int socket_fd) { - int client_fd = accept(socket_fd, NULL, NULL); - if (client_fd < 0) { - ARROW_LOG(ERROR) << "Error reading from socket."; - return -1; - } - return client_fd; -} - -std::unique_ptr ReadMessageAsync(int sock) { - int64_t size; - Status s = ReadBytes(sock, reinterpret_cast(&size), sizeof(int64_t)); - if (!s.ok()) { - // The other side has closed the socket. - ARROW_LOG(DEBUG) << "Socket has been closed, or some other error has occurred."; - close(sock); - return NULL; - } - auto message = std::unique_ptr(new uint8_t[size]); - s = ReadBytes(sock, message.get(), size); - if (!s.ok()) { - // The other side has closed the socket. - ARROW_LOG(DEBUG) << "Socket has been closed, or some other error has occurred."; - close(sock); - return NULL; - } - return message; -} - -} // namespace plasma diff --git a/cpp/src/plasma/io.h b/cpp/src/plasma/io.h deleted file mode 100644 index 745518ab227..00000000000 --- a/cpp/src/plasma/io.h +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#ifndef PLASMA_IO_H -#define PLASMA_IO_H - -#include -#include -#include -#include - -#include -#include -#include - -#include "arrow/status.h" -#include "plasma/compat.h" - -namespace plasma { - -namespace flatbuf { - -// Forward declaration outside the namespace, which is defined in plasma_generated.h. -enum class MessageType : int64_t; - -} // namespace flatbuf - -// TODO(pcm): Replace our own custom message header (message type, -// message length, plasma protocol verion) with one that is serialized -// using flatbuffers. 
-constexpr int64_t kPlasmaProtocolVersion = 0x0000000000000000; - -using arrow::Status; - -Status WriteBytes(int fd, uint8_t* cursor, size_t length); - -Status WriteMessage(int fd, flatbuf::MessageType type, int64_t length, uint8_t* bytes); - -Status ReadBytes(int fd, uint8_t* cursor, size_t length); - -Status ReadMessage(int fd, flatbuf::MessageType* type, std::vector* buffer); - -int BindIpcSock(const std::string& pathname, bool shall_listen); - -int ConnectIpcSock(const std::string& pathname); - -Status ConnectIpcSocketRetry(const std::string& pathname, int num_retries, - int64_t timeout, int* fd); - -int AcceptClient(int socket_fd); - -std::unique_ptr ReadMessageAsync(int sock); - -} // namespace plasma - -#endif // PLASMA_IO_H diff --git a/cpp/src/plasma/io/basic_connection.cc b/cpp/src/plasma/io/basic_connection.cc new file mode 100644 index 00000000000..ce5ea607910 --- /dev/null +++ b/cpp/src/plasma/io/basic_connection.cc @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "plasma/io/basic_connection.h" +#include "arrow/util/logging.h" + +#include +#include +#include +#include + +namespace plasma { +namespace io { + +/// Connect a Unix local domain socket. 
+/// +/// \param socket The socket to connect. +/// \param socket_name The name/path of the socket. +/// \return The error code of the connect attempt (no error on success). +std::error_code UnixDomainSocketConnect(asio::local::stream_protocol::socket& socket, + const std::string& socket_name) { + asio::local::stream_protocol::endpoint endpoint(socket_name); + std::error_code ec; + socket.connect(endpoint, ec); + if (ec) { + // Close the socket if the connect failed. + std::error_code close_error; + socket.close(close_error); + } + return ec; +} + +PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& name) { + // TODO(suquark): Maybe use "kNumConnectAttempts" and "kConnectTimeoutMs"? + constexpr int num_retries = 50; + constexpr int timeout_ms = 100; + ARROW_CHECK(!name.empty()); +#ifndef _WIN32 + asio::basic_stream_socket socket(io_context); + for (int i = 0; i < num_retries; i++) { + std::error_code ec = UnixDomainSocketConnect(socket, name); + if (!ec) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(timeout_ms)); + if (i > 0) { + ARROW_LOG(ERROR) << "Retrying to connect to socket for pathname " << name + << " (num_attempts = " << i << ", num_retries = " << num_retries + << ")"; + } + } + return socket; +#else +// For windows: https://stackoverflow.com/questions/1236460/c-using-windows-named-pipes +#error "Windows has not been supported." +#endif +} + +PlasmaAcceptor CreateLocalAcceptor(asio::io_context& io_context, + const std::string& name) { +#ifndef _WIN32 + return PlasmaAcceptor(io_context, asio::local::stream_protocol::endpoint(name)); +#else +// For windows: https://stackoverflow.com/questions/1236460/c-using-windows-named-pipes +#error "Windows has not been supported."
+#endif +} + +template +Connection::Connection(T&& stream) + : stream_(std::move(stream)), + async_write_in_flight_(false), + async_write_max_messages_(1), + async_write_queue_() {} + +template +Connection::~Connection() { + // If there are any pending messages, invoke their callbacks with an IOError status. + for (const auto& write_buffer : async_write_queue_) { + write_buffer->handler( + std::error_code(static_cast(std::errc::io_error), std::system_category())); + } +} + +template +std::error_code Connection::ReadBuffer(const asio::mutable_buffer& buffer) { + std::error_code ec; + // Loop until all bytes are read while handling interrupts. + uint64_t bytes_remaining = asio::buffer_size(buffer); + uint64_t position = 0; + while (bytes_remaining != 0) { + size_t bytes_read = + stream_.read_some(asio::buffer(buffer + position, bytes_remaining), ec); + position += bytes_read; + bytes_remaining -= bytes_read; + if (ec.value() == EINTR) { + continue; + } else if (ec) { + return ec; + } + } + return std::error_code(); +} + +template +std::error_code Connection::ReadBuffer( + const std::vector& buffer) { + // Loop until all bytes are read while handling interrupts. + for (const auto& b : buffer) { + auto ec = ReadBuffer(b); + if (ec) return ec; + } + return std::error_code(); +} + +/// Write a buffer to this connection. +/// +/// \param buffer The buffer. +template +std::error_code Connection::WriteBuffer(const asio::const_buffer& buffer) { + std::error_code error; + // Loop until all bytes are written while handling interrupts. + // When profiling with pprof, unhandled interrupts were being sent by the profiler to + // the raylet process, which was causing synchronous reads and writes to fail. 
+ uint64_t bytes_remaining = asio::buffer_size(buffer); + uint64_t position = 0; + while (bytes_remaining != 0) { + size_t bytes_written = + stream_.write_some(asio::buffer(buffer + position, bytes_remaining), error); + position += bytes_written; + bytes_remaining -= bytes_written; + if (error.value() == EINTR) { + continue; + } else if (error) { + return error; + } + } + return std::error_code(); +} + +template +std::error_code Connection::WriteBuffer( + const std::vector& buffer) { + std::error_code error; + // Loop until all bytes are written while handling interrupts. + // When profiling with pprof, unhandled interrupts were being sent by the profiler to + // the raylet process, which was causing synchronous reads and writes to fail. + for (const auto& b : buffer) { + error = WriteBuffer(b); + if (error) { + return error; + } + } + return std::error_code(); +} + +template +void Connection::WriteBufferAsync(std::unique_ptr write_buffer) { + async_writes_ += 1; + auto size = async_write_queue_.size(); + auto size_is_power_of_two = (size & (size - 1)) == 0; + if (size > 1000 && size_is_power_of_two) { + ARROW_LOG(WARNING) << "Connection has " << size << " buffered async writes"; + } + async_write_queue_.push_back(std::move(write_buffer)); + if (!async_write_in_flight_) { + DoAsyncWrites(); + } +} + +// Shuts down socket for this connection. +template +void Connection::Close() { + std::error_code ec; + stream_.close(ec); +} + +template +std::string Connection::DebugString() const { + std::stringstream result; + result << "\n- bytes read: " << bytes_read_; + result << "\n- bytes written: " << bytes_written_; + result << "\n- num async writes: " << async_writes_; + result << "\n- num sync writes: " << sync_writes_; + result << "\n- writing: " << async_write_in_flight_; + result << "\n- pending async messages: " << async_write_queue_.size(); + return result.str(); +} + +template +void Connection::DoAsyncWrites() { + // Make sure we were not writing to the socket. 
+ ARROW_CHECK(!async_write_in_flight_); + async_write_in_flight_ = true; + + // Do an async write of everything currently in the queue to the socket. + std::vector message_buffers; + int num_messages = 0; + for (const auto& write_buffer : async_write_queue_) { + write_buffer->ToBuffers(message_buffers); + num_messages++; + if (num_messages >= async_write_max_messages_) { + break; + } + } + + // Ensure lambda holds a reference to this. + auto this_ptr = this->shared_from_this(); + asio::async_write(stream_, message_buffers, + [this, this_ptr, num_messages](const std::error_code& ec, + size_t bytes_transferred) { + bytes_written_ += bytes_transferred; + // Call the handlers for the written messages. + for (int i = 0; i < num_messages; i++) { + auto write_buffer = std::move(async_write_queue_.front()); + write_buffer->handler(ec); + async_write_queue_.pop_front(); // release object + } + // We finished writing, so mark that we're no longer doing an + // async write. + async_write_in_flight_ = false; + // If there is more to write, try to write the rest. + if (!async_write_queue_.empty()) { + DoAsyncWrites(); + } + }); +} + +// We have to fill the template of all possible types. +template class Connection; + +} // namespace io +} // namespace plasma diff --git a/cpp/src/plasma/io/basic_connection.h b/cpp/src/plasma/io/basic_connection.h new file mode 100644 index 00000000000..32f8445e867 --- /dev/null +++ b/cpp/src/plasma/io/basic_connection.h @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef PLASMA_IO_BASIC_CONNECTION_H +#define PLASMA_IO_BASIC_CONNECTION_H + +#ifndef ASIO_STANDALONE +#define ASIO_STANDALONE +#endif + +#include +#include +#include +#include +#include +#include + +#include "asio.hpp" // NOLINT + +namespace plasma { +namespace io { + +using AsyncWriteCallback = std::function; +// TODO(suquark): Change it according to the platform. +using PlasmaStream = asio::basic_stream_socket; +using PlasmaAcceptor = asio::local::stream_protocol::acceptor; + +/// Create a local acceptor appropriate for the current platform. +PlasmaAcceptor CreateLocalAcceptor(asio::io_context& io_context, const std::string& name); + +/// Create a local stream appropriate for the current platform. +PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& name); + +/// A message that is queued for writing asynchronously. +struct AsyncWriteBuffer { + virtual void ToBuffers(std::vector& message_buffers) {} + AsyncWriteCallback handler; + virtual ~AsyncWriteBuffer() {} +}; + +template +class Connection : public std::enable_shared_from_this> { + public: + explicit Connection(T&& stream); + + ~Connection(); + + /// Read a buffer from this connection. + /// + /// \param buffer The output buffer. + std::error_code ReadBuffer(const asio::mutable_buffer& buffer); + + /// Read buffers from this connection. + /// + /// \param buffer The output vector of buffers. + std::error_code ReadBuffer(const std::vector& buffer); + + /// Write a buffer to this connection. + /// + /// \param buffer The buffer.
+ std::error_code WriteBuffer(const asio::const_buffer& buffer); + + /// Write buffers to this connection. + /// + /// \param buffer The vector of buffers. + std::error_code WriteBuffer(const std::vector& buffer); + + /// Write buffers to this connection async. + /// + /// \param write_buffer The buffer to write async. + void WriteBufferAsync(std::unique_ptr write_buffer); + + /// Whether the stream is open. + inline bool IsOpen() { return stream_.is_open(); } + + /// Shuts down the stream for this connection. + void Close(); + + /// Get the native handle from the stream. + inline int GetNativeHandle() { return stream_.native_handle(); } + + /// Get the debug string. + std::string DebugString() const; + + protected: + /// The stream that supports most asio protocols (read, read_some, write, + /// write_some, async_read, async_write, async_read_some, async_write_some). + T stream_; + + /// Whether we are in the middle of an async write. + bool async_write_in_flight_; + + /// Max number of messages to write out at once. + const int async_write_max_messages_; + + /// List of pending messages to write. + std::deque> async_write_queue_; + + /// Count of sync messages sent total. + int64_t sync_writes_ = 0; + + /// Count of async messages sent total. + int64_t async_writes_ = 0; + + /// Count of bytes sent total. + int64_t bytes_written_ = 0; + + /// Count of bytes read total. + int64_t bytes_read_ = 0; + + private: + /// Asynchronously flushes the write queue. While async writes are running, the flag + /// async_write_in_flight_ will be set. This should only be called when no async writes + /// are currently in flight. 
+ void DoAsyncWrites(); +}; + +} // namespace io +} // namespace plasma + +#endif // PLASMA_IO_BASIC_CONNECTION_H diff --git a/cpp/src/plasma/io/connection.cc b/cpp/src/plasma/io/connection.cc new file mode 100644 index 00000000000..64a42b376bc --- /dev/null +++ b/cpp/src/plasma/io/connection.cc @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "plasma/io/connection.h" + +#include +#include +#include +#include + +#include "arrow/util/logging.h" +#include "plasma/fling.h" +#include "plasma/plasma_generated.h" +#include "plasma/protocol.h" + +// TODO(pcm): Replace our own custom message header (message type, +// message length, plasma protocol version) with one that is serialized +// using flatbuffers. +constexpr int64_t kPlasmaProtocolVersion = 0x0000000000000000; + +namespace plasma { +namespace io { + +using flatbuf::MessageType; + +Status asio_to_arrow_status(const std::error_code& ec) { + if (!ec) { + return Status::OK(); + } + if (ec.value() == EPIPE || ec.value() == EBADF || ec.value() == ECONNRESET) { + ARROW_LOG(WARNING) << "Received SIGPIPE, BAD FILE DESCRIPTOR, or ECONNRESET when " + "processing a message.
The client on the other end may " + "have hung up."; + } + return Status::IOError("Error code = ", strerror(ec.value())); +} + +struct AsyncMessageWriteBuffer : public AsyncWriteBuffer { + AsyncMessageWriteBuffer(int64_t version, int64_t type, int64_t length, + const uint8_t* message, AsyncWriteCallback callback) + : write_version(version), write_type(type), write_length(length) { + write_message.resize(length); + write_message.assign(message, message + length); + handler = callback; + } + + void ToBuffers(std::vector& message_buffers) override { + message_buffers.push_back(asio::buffer(&write_version, sizeof(write_version))); + message_buffers.push_back(asio::buffer(&write_type, sizeof(write_type))); + message_buffers.push_back(asio::buffer(&write_length, sizeof(write_length))); + message_buffers.push_back(asio::buffer(write_message)); + } + + int64_t write_version; + int64_t write_type; + uint64_t write_length; + std::vector write_message; +}; + +std::shared_ptr ServerConnection::shared_from_this() { + return std::static_pointer_cast(PlasmaConnection::shared_from_this()); +} + +std::shared_ptr ServerConnection::Create(PlasmaStream&& stream) { + std::shared_ptr self(new ServerConnection(std::move(stream))); + return self; +} + +Status ServerConnection::ReadMessage(int64_t type, std::vector* message) { + int64_t read_version, read_type, read_length; + // Wait for a message header from the client. The message header includes the + // protocol version, the message type, and the length of the message. + std::vector header; + header.push_back(asio::buffer(&read_version, sizeof(read_version))); + header.push_back(asio::buffer(&read_type, sizeof(read_type))); + header.push_back(asio::buffer(&read_length, sizeof(read_length))); + + auto ec = PlasmaConnection::ReadBuffer(header); + if (ec) { + return asio_to_arrow_status(ec); + } + // If there was no error, make sure the protocol version matches. 
+ if (read_version != kPlasmaProtocolVersion) { + return Status::ProtocolError( + "Expected Plasma message protocol version: ", kPlasmaProtocolVersion, + ", got protocol version: ", read_version); + } + if (type != read_type) { + if (read_type == static_cast(MessageType::PlasmaDisconnectClient)) { + // Disconnected by client. + return Status::IOError("The other side disconnected."); + } + return Status::IOError("Connection corrupted. Expected message type: ", type, + "; got message type: ", read_type, + ". Check logs or dmesg for previous errors."); + } + // Create read buffer. + message->resize(read_length); + auto buffer = asio::buffer(*message); + // Wait for the message to be read. + return asio_to_arrow_status(PlasmaConnection::ReadBuffer(buffer)); +} + +Status ServerConnection::WriteMessage(int64_t type, int64_t length, + const uint8_t* message) { + PlasmaConnection::sync_writes_ += 1; + PlasmaConnection::bytes_written_ += length; + + std::vector message_buffers; + auto write_version = kPlasmaProtocolVersion; + message_buffers.push_back(asio::buffer(&write_version, sizeof(write_version))); + message_buffers.push_back(asio::buffer(&type, sizeof(type))); + message_buffers.push_back(asio::buffer(&length, sizeof(length))); + message_buffers.push_back(asio::buffer(message, length)); + return asio_to_arrow_status(PlasmaConnection::WriteBuffer(message_buffers)); +} + +void ServerConnection::WriteMessageAsync(int64_t type, int64_t length, + const uint8_t* message, + const AsyncWriteCallback& handler) { + auto write_buffer = std::unique_ptr(new AsyncMessageWriteBuffer( + kPlasmaProtocolVersion, type, length, message, handler)); + PlasmaConnection::WriteBufferAsync(std::move(write_buffer)); +} + +Status ServerConnection::RecvFd(int* fd) { + *fd = recv_fd(GetNativeHandle()); + ARROW_CHECK(*fd >= 0); // recv_fd() reports failure as a value < 0. + return Status::OK(); +} + +Status ServerConnection::Disconnect() { + if (!IsOpen()) { + ARROW_LOG(WARNING) << "The client is not connected.
'Disconnect()' is ignored."; + return Status::OK(); + } + // Write the disconnection message. + auto status = + WriteMessage(static_cast(MessageType::PlasmaDisconnectClient), 0, NULLPTR); + Close(); // Close the stream anyway. + return status; +} + +Status ServerConnection::ReadNotificationMessage(std::unique_ptr& message) { + int64_t size; + auto ec = ReadBuffer(asio::mutable_buffer(&size, sizeof(size))); + if (ec) { + // The other side has closed the socket. + ARROW_LOG(DEBUG) << "Socket has been closed, or some other error has occurred."; + return asio_to_arrow_status(ec); + } + message.reset(new uint8_t[size]); + ec = ReadBuffer(asio::mutable_buffer(message.get(), size)); + if (ec) { + // The other side has closed the socket. + ARROW_LOG(DEBUG) << "Socket has been closed, or some other error has occurred."; + return asio_to_arrow_status(ec); + } + return Status::OK(); +} + +ServerConnection::ServerConnection(PlasmaStream&& stream) + : PlasmaConnection(std::move(stream)) {} + +std::shared_ptr ClientConnection::Create( + PlasmaStream&& stream, MessageHandler& message_handler, + const std::string& debug_label) { + return std::shared_ptr( + new ClientConnection(std::move(stream), message_handler, debug_label)); +} + +ClientConnection::ClientConnection(PlasmaStream&& stream, MessageHandler& message_handler, + const std::string& debug_label) + : ServerConnection(std::move(stream)), + debug_label_(debug_label), + message_handler_(message_handler) {} + +std::shared_ptr ClientConnection::shared_from_this() { + return std::static_pointer_cast(ServerConnection::shared_from_this()); +} + +void ClientConnection::ProcessMessages() { + // Wait for a message header from the client. The message header includes the + // protocol version, the message type, and the length of the message. 
+ std::vector header{ + asio::buffer(&read_version_, sizeof(read_version_)), + asio::buffer(&read_type_, sizeof(read_type_)), + asio::buffer(&read_length_, sizeof(read_length_))}; + + asio::async_read(ServerConnection::stream_, header, + std::bind(&ClientConnection::ProcessMessageHeader, shared_from_this(), + std::placeholders::_1)); // Ignore byte_transferred +} + +void ClientConnection::ProcessMessageHeader(const std::error_code& error) { + if (error) { + // If there was an error, disconnect the client. + ProcessError(error); + return; + } + + // If there was no error, make sure the protocol version matches. + // TODO(suquark): Don't let server die here. + ARROW_CHECK(read_version_ == kPlasmaProtocolVersion); + // Resize the message buffer to match the received length. + read_message_.resize(read_length_); + ServerConnection::bytes_read_ += read_length_; + // Wait for the message to be read. + asio::async_read(ServerConnection::stream_, asio::buffer(read_message_), + std::bind(&ClientConnection::ProcessMessageBody, shared_from_this(), + std::placeholders::_1)); +} + +void ClientConnection::ProcessMessageBody(const std::error_code& error) { + if (error) { + ProcessError(error); + return; + } + auto start = std::chrono::system_clock::now(); + ProcessMessage(read_type_, read_length_, read_message_.data()); + auto end = std::chrono::system_clock::now(); + auto interval = std::chrono::duration(end - start); + if (interval.count() > 100.0) { + ARROW_LOG(WARNING) << "[" << debug_label_ << "]ProcessMessage with type " + << read_type_ << " took " << interval.count() << " ms."; + } +} + +void ClientConnection::ProcessError(const std::error_code& ec) { + ARROW_LOG(ERROR) + << "Failed when processing message. Disconnecting the client. Error code = " << ec; + // If there was an error, disconnect the client. 
+ PlasmaConnection::Close(); +} + +void ClientConnection::ProcessMessage(int64_t type, int64_t length, const uint8_t* data) { + message_handler_(shared_from_this(), type, length, data); +} + +struct AsyncObjectNotificationWriteBuffer : public AsyncWriteBuffer { + static std::unique_ptr MakeDeletion( + const ObjectID& object_id) { + auto message = new std::vector(); + SerializeObjectDeletionNotification(object_id, message); + return std::unique_ptr( + new AsyncObjectNotificationWriteBuffer(message)); + } + + static std::unique_ptr MakeReady( + const ObjectID& object_id, const ObjectTableEntry& entry) { + auto message = new std::vector(); + SerializeObjectSealedNotification(object_id, entry, message); + return std::unique_ptr( + new AsyncObjectNotificationWriteBuffer(message)); + } + + void ToBuffers(std::vector& message_buffers) override { + message_buffers.push_back(asio::buffer(&size, sizeof(size))); + message_buffers.push_back(asio::buffer(*notification_msg)); + } + + std::unique_ptr> notification_msg; + int64_t size; + + protected: + explicit AsyncObjectNotificationWriteBuffer(std::vector* message) { + // Serialize the object. + notification_msg.reset(message); + size = message->size(); + handler = [](const asio::error_code& status) { + auto errno_ = status.value(); + if (!errno_) { + return; + } + if (errno_ == EAGAIN || errno_ == EWOULDBLOCK || errno_ == EINTR) { + ARROW_LOG(DEBUG) << "The socket's send buffer is full, so we are caching this " + "notification and will send it later."; + ARROW_LOG(WARNING) << "Blocked unexpectly when sending message async."; + } else { + ARROW_LOG(WARNING) << "Failed to send notification to client."; + if (errno_ == EPIPE) { + // TODO(suquark): We could probably close the socket here. 
+ } + } + }; + } +}; + +Status ClientConnection::SendFd(int fd) { + ARROW_CHECK(send_fd(GetNativeHandle(), fd)); + return Status::OK(); +} + +void ClientConnection::SendObjectDeletionAsync(const ObjectID& object_id) { + auto raw_ptr = AsyncObjectNotificationWriteBuffer::MakeDeletion(object_id).release(); + auto write_buffer = + std::unique_ptr(static_cast(raw_ptr)); + // Attempt to send a notification about this object ID. + WriteBufferAsync(std::move(write_buffer)); +} + +void ClientConnection::SendObjectReadyAsync(const ObjectID& object_id, + const ObjectTableEntry& entry) { + auto raw_ptr = + AsyncObjectNotificationWriteBuffer::MakeReady(object_id, entry).release(); + auto write_buffer = + std::unique_ptr(static_cast(raw_ptr)); + // Attempt to send a notification about this object ID. + WriteBufferAsync(std::move(write_buffer)); +} + +} // namespace io +} // namespace plasma diff --git a/cpp/src/plasma/io/connection.h b/cpp/src/plasma/io/connection.h new file mode 100644 index 00000000000..a87a124211a --- /dev/null +++ b/cpp/src/plasma/io/connection.h @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#ifndef PLASMA_IO_CONNECTION_H +#define PLASMA_IO_CONNECTION_H + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/status.h" +#include "plasma/common.h" +#include "plasma/io/basic_connection.h" + +namespace plasma { +namespace io { + +using arrow::Status; + +using PlasmaConnection = Connection; + +Status asio_to_arrow_status(const std::error_code& ec); + +/// A generic type representing a client connection to a server. This typename +/// can be used to write messages synchronously to the server. +class ServerConnection : public PlasmaConnection { + public: + std::shared_ptr shared_from_this(); + + /// Allocate a new server connection. + /// + /// \param stream A reference to the server stream. + /// \return std::shared_ptr. + static std::shared_ptr Create(PlasmaStream&& stream); + + /// Write a message to the client. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param length The size in bytes of the message. + /// \param message A pointer to the message buffer. + /// \return Status. + Status WriteMessage(int64_t type, int64_t length, const uint8_t* message); + + /// Read a message from the client. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param message A pointer to the message vector. + /// \return Status. + Status ReadMessage(int64_t type, std::vector* message); + + /// Write a message to the client asynchronously. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param length The size in bytes of the message. + /// \param message A pointer to the message buffer. + /// \param handler A callback to run on write completion. + void WriteMessageAsync(int64_t type, int64_t length, const uint8_t* message, + const AsyncWriteCallback& handler); + + Status RecvFd(int* fd); + + Status Disconnect(); + + Status ReadNotificationMessage(std::unique_ptr& message); + + protected: + /// A private constructor for a server connection. 
+ explicit ServerConnection(PlasmaStream&& stream); +}; + +class ClientConnection; +using MessageHandler = std::function, int64_t type, + int64_t length, const uint8_t*)>; + +/// A generic type representing a client connection on a server. In addition to +/// writing messages to the client, like in ServerConnection, this typename can +/// also be used to process messages asynchronously from client. +class ClientConnection : public ServerConnection { + public: + /// Allocate a new node client connection. + /// + /// \param stream The client stream. + /// \param message_handler A reference to the message handler. + /// \param debug_label The debug label. + /// \return std::shared_ptr. + static std::shared_ptr Create(PlasmaStream&& stream, + MessageHandler& message_handler, + const std::string& debug_label); + + std::shared_ptr shared_from_this(); + + /// Listen for and process messages from the client connection. Once a + /// message has been fully received, the client manager's + /// ProcessClientMessage handler will be called. + void ProcessMessages(); + + Status SendFd(int fd); + + inline bool ObjectIDExists(const ObjectID& object_id) { + return object_ids.find(object_id) != object_ids.end(); + } + + inline void RemoveObjectID(const ObjectID& object_id) { object_ids.erase(object_id); } + + inline int RemoveObjectIDIfExists(const ObjectID& object_id) { + auto it = object_ids.find(object_id); + if (it != object_ids.end()) { + object_ids.erase(it); + // Return 1 to indicate that the client was removed. + return 1; + } else { + // Return 0 to indicate that the client was not removed. + return 0; + } + } + + /// Send notifications about sealed objects to the subscribers. This is called + /// in SealObject. + void SendObjectReadyAsync(const ObjectID& object_id, const ObjectTableEntry& entry); + + /// Send notifications about evicted objects to the subscribers. + void SendObjectDeletionAsync(const ObjectID& object_id); + + /// Object ids that are used by this client. 
+ std::unordered_set object_ids; + /// File descriptors that are used by this client. + std::unordered_set used_fds; + + private: + /// Process the message header from the client. + /// \param ec The returned error code. + void ProcessMessageHeader(const std::error_code& ec); + + /// Process the message body from the client. + /// \param ec The returned error code. + void ProcessMessageBody(const std::error_code& ec); + + /// Process the message from the client. + /// \param type The type of the message. + /// \param length The length of the message. + /// \param data The data buffer of the message. + void ProcessMessage(int64_t type, int64_t length, const uint8_t* data); + + /// Process an error from reading the message from the client. + /// \param ec The returned error code. + void ProcessError(const std::error_code& ec); + + /// A private constructor for a node client connection. + ClientConnection(PlasmaStream&& stream, MessageHandler& message_handler, + const std::string& debug_label); + + /// A label used for debug messages. + const std::string debug_label_; + + /// The handler for a message from the client. + MessageHandler message_handler_; + + int64_t read_version_; + int64_t read_type_; + uint64_t read_length_; + + /// Buffers for the current message being read from the client. 
+ std::vector read_message_; +}; + +} // namespace io +} // namespace plasma + +#endif // PLASMA_IO_CONNECTION_H diff --git a/cpp/src/plasma/malloc.cc b/cpp/src/plasma/malloc.cc index bb027a6cb90..14217894f38 100644 --- a/cpp/src/plasma/malloc.cc +++ b/cpp/src/plasma/malloc.cc @@ -29,6 +29,7 @@ #include #include +#include "arrow/util/logging.h" #include "plasma/common.h" #include "plasma/plasma.h" diff --git a/cpp/src/plasma/malloc.h b/cpp/src/plasma/malloc.h index a081190b3ef..92d6537e948 100644 --- a/cpp/src/plasma/malloc.h +++ b/cpp/src/plasma/malloc.h @@ -35,8 +35,8 @@ void GetMallocMapinfo(void* addr, int* fd, int64_t* map_length, ptrdiff_t* offse /// Get the mmap size corresponding to a specific file descriptor. /// -/// @param fd The file descriptor to look up. -/// @return The size of the corresponding memory-mapped file. +/// \param fd The file descriptor to look up. +/// \return The size of the corresponding memory-mapped file. int64_t GetMmapSize(int fd); struct MmapRecord { diff --git a/cpp/src/plasma/plasma.cc b/cpp/src/plasma/plasma.cc index d0ef4f8d317..538422d3b49 100644 --- a/cpp/src/plasma/plasma.cc +++ b/cpp/src/plasma/plasma.cc @@ -16,16 +16,7 @@ // under the License. #include "plasma/plasma.h" - -#include -#include -#include - #include "plasma/common.h" -#include "plasma/common_generated.h" -#include "plasma/protocol.h" - -namespace fb = plasma::flatbuf; namespace plasma { @@ -33,42 +24,6 @@ ObjectTableEntry::ObjectTableEntry() : pointer(nullptr), ref_count(0) {} ObjectTableEntry::~ObjectTableEntry() { pointer = nullptr; } -int WarnIfSigpipe(int status, int client_sock) { - if (status >= 0) { - return 0; - } - if (errno == EPIPE || errno == EBADF || errno == ECONNRESET) { - ARROW_LOG(WARNING) << "Received SIGPIPE, BAD FILE DESCRIPTOR, or ECONNRESET when " - "sending a message to client on fd " - << client_sock - << ". 
The client on the other end may " - "have hung up."; - return errno; - } - ARROW_LOG(FATAL) << "Failed to write message to client on fd " << client_sock << "."; - return -1; // This is never reached. -} - -/** - * This will create a new ObjectInfo buffer. The first sizeof(int64_t) bytes - * of this buffer are the length of the remaining message and the - * remaining message is a serialized version of the object info. - * - * @param object_info The object info to be serialized - * @return The object info buffer. It is the caller's responsibility to free - * this buffer with "delete" after it has been used. - */ -std::unique_ptr CreateObjectInfoBuffer(fb::ObjectInfoT* object_info) { - flatbuffers::FlatBufferBuilder fbb; - auto message = fb::CreateObjectInfo(fbb, object_info); - fbb.Finish(message); - auto notification = - std::unique_ptr(new uint8_t[sizeof(int64_t) + fbb.GetSize()]); - *(reinterpret_cast(notification.get())) = fbb.GetSize(); - memcpy(notification.get() + sizeof(int64_t), fbb.GetBufferPointer(), fbb.GetSize()); - return notification; -} - ObjectTableEntry* GetObjectTableEntry(PlasmaStoreInfo* store_info, const ObjectID& object_id) { auto it = store_info->objects.find(object_id); diff --git a/cpp/src/plasma/plasma.h b/cpp/src/plasma/plasma.h index e23969d05ff..306ae27d345 100644 --- a/cpp/src/plasma/plasma.h +++ b/cpp/src/plasma/plasma.h @@ -18,15 +18,6 @@ #ifndef PLASMA_PLASMA_H #define PLASMA_PLASMA_H -#include -#include -#include -#include -#include -#include -#include -#include // pid_t - #include #include #include @@ -35,8 +26,6 @@ #include "plasma/compat.h" #include "arrow/status.h" -#include "arrow/util/logging.h" -#include "arrow/util/macros.h" #include "plasma/common.h" #ifdef PLASMA_CUDA @@ -45,32 +34,9 @@ using arrow::cuda::CudaIpcMemHandle; namespace plasma { -namespace flatbuf { -struct ObjectInfoT; -} // namespace flatbuf - -#define HANDLE_SIGPIPE(s, fd_) \ - do { \ - Status _s = (s); \ - if (!_s.ok()) { \ - if (errno == EPIPE || errno == 
EBADF || errno == ECONNRESET) { \ - ARROW_LOG(WARNING) \ - << "Received SIGPIPE, BAD FILE DESCRIPTOR, or ECONNRESET when " \ - "sending a message to client on fd " \ - << fd_ \ - << ". " \ - "The client on the other end may have hung up."; \ - } else { \ - return _s; \ - } \ - } \ - } while (0); - /// Allocation granularity used in plasma for object allocation. constexpr int64_t kBlockSize = 64; -struct Client; - // TODO(pcm): Replace this by the flatbuffers message PlasmaObjectSpec. struct PlasmaObject { #ifdef PLASMA_CUDA @@ -116,31 +82,13 @@ struct PlasmaStoreInfo { /// Get an entry from the object table and return NULL if the object_id /// is not present. /// -/// @param store_info The PlasmaStoreInfo that contains the object table. -/// @param object_id The object_id of the entry we are looking for. -/// @return The entry associated with the object_id or NULL if the object_id +/// \param store_info The PlasmaStoreInfo that contains the object table. +/// \param object_id The object_id of the entry we are looking for. +/// \return The entry associated with the object_id or NULL if the object_id /// is not present. ObjectTableEntry* GetObjectTableEntry(PlasmaStoreInfo* store_info, const ObjectID& object_id); -/// Print a warning if the status is less than zero. This should be used to check -/// the success of messages sent to plasma clients. We print a warning instead of -/// failing because the plasma clients are allowed to die. This is used to handle -/// situations where the store writes to a client file descriptor, and the client -/// may already have disconnected. If we have processed the disconnection and -/// closed the file descriptor, we should get a BAD FILE DESCRIPTOR error. If we -/// have not, then we should get a SIGPIPE. If we write to a TCP socket that -/// isn't connected yet, then we should get an ECONNRESET. -/// -/// @param status The status to check. If it is less less than zero, we will -/// print a warning. 
-/// @param client_sock The client socket. This is just used to print some extra -/// information. -/// @return The errno set. -int WarnIfSigpipe(int status, int client_sock); - -std::unique_ptr CreateObjectInfoBuffer(flatbuf::ObjectInfoT* object_info); - } // namespace plasma #endif // PLASMA_PLASMA_H diff --git a/cpp/src/plasma/protocol.cc b/cpp/src/plasma/protocol.cc index a8786477182..0d489028df2 100644 --- a/cpp/src/plasma/protocol.cc +++ b/cpp/src/plasma/protocol.cc @@ -19,11 +19,12 @@ #include +#include "arrow/util/logging.h" #include "flatbuffers/flatbuffers.h" #include "plasma/plasma_generated.h" #include "plasma/common.h" -#include "plasma/io.h" +#include "plasma/io/connection.h" #ifdef PLASMA_CUDA #include "arrow/gpu/cuda_api.h" @@ -52,15 +53,6 @@ ToFlatbuffer(flatbuffers::FlatBufferBuilder* fbb, const ObjectID* object_ids, return fbb->CreateVector(results); } -Status PlasmaReceive(int sock, MessageType message_type, std::vector* buffer) { - MessageType type; - RETURN_NOT_OK(ReadMessage(sock, &type, buffer)); - ARROW_CHECK(type == message_type) - << "type = " << static_cast(type) - << ", message_type = " << static_cast(message_type); - return Status::OK(); -} - // Helper function to create a vector of elements from Data (Request/Reply struct). // The Getter function is used to extract one element from Data. 
template @@ -73,13 +65,6 @@ void ToVector(const Data& request, std::vector* out, const Getter& getter) { } } -template -Status PlasmaSend(int sock, MessageType message_type, flatbuffers::FlatBufferBuilder* fbb, - const Message& message) { - fbb->Finish(message); - return WriteMessage(sock, message_type, fbb->GetSize(), fbb->GetBufferPointer()); -} - Status PlasmaErrorStatus(fb::PlasmaError plasma_error) { switch (plasma_error) { case fb::PlasmaError::OK: @@ -96,17 +81,39 @@ Status PlasmaErrorStatus(fb::PlasmaError plasma_error) { return Status::OK(); } +Status PlasmaSend(const std::shared_ptr& client, + MessageType message_type, flatbuffers::FlatBufferBuilder* fbb) { + if (fbb) { + return client->WriteMessage(static_cast(message_type), fbb->GetSize(), + fbb->GetBufferPointer()); + } else { + return client->WriteMessage(static_cast(message_type), 0, NULLPTR); + } +} + +Status PlasmaSend(const std::shared_ptr& client, + MessageType message_type, flatbuffers::FlatBufferBuilder* fbb) { + if (fbb) { + return client->WriteMessage(static_cast(message_type), fbb->GetSize(), + fbb->GetBufferPointer()); + } else { + return client->WriteMessage(static_cast(message_type), 0, NULLPTR); + } +} + // Create messages. 
-Status SendCreateRequest(int sock, ObjectID object_id, int64_t data_size, - int64_t metadata_size, int device_num) { +Status SendCreateRequest(const std::shared_ptr& client, + ObjectID object_id, int64_t data_size, int64_t metadata_size, + int device_num) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaCreateRequest(fbb, fbb.CreateString(object_id.binary()), data_size, metadata_size, device_num); - return PlasmaSend(sock, MessageType::PlasmaCreateRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaCreateRequest, &fbb); } -Status ReadCreateRequest(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadCreateRequest(const uint8_t* data, size_t size, ObjectID* object_id, int64_t* data_size, int64_t* metadata_size, int* device_num) { DCHECK(data); auto message = flatbuffers::GetRoot(data); @@ -118,8 +125,9 @@ Status ReadCreateRequest(uint8_t* data, size_t size, ObjectID* object_id, return Status::OK(); } -Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object, - PlasmaError error_code, int64_t mmap_size) { +Status SendCreateReply(const std::shared_ptr& client, + ObjectID object_id, PlasmaObject* object, PlasmaError error_code, + int64_t mmap_size) { flatbuffers::FlatBufferBuilder fbb; PlasmaObjectSpec plasma_object(object->store_fd, object->data_offset, object->data_size, object->metadata_offset, object->metadata_size, @@ -148,10 +156,11 @@ Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object, #endif } auto message = crb.Finish(); - return PlasmaSend(sock, MessageType::PlasmaCreateReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaCreateReply, &fbb); } -Status ReadCreateReply(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadCreateReply(const uint8_t* data, size_t size, ObjectID* object_id, PlasmaObject* object, int* store_fd, int64_t* mmap_size) { DCHECK(data); auto message = flatbuffers::GetRoot(data); @@ 
-176,18 +185,19 @@ Status ReadCreateReply(uint8_t* data, size_t size, ObjectID* object_id, return PlasmaErrorStatus(message->error()); } -Status SendCreateAndSealRequest(int sock, const ObjectID& object_id, - const std::string& data, const std::string& metadata, - unsigned char* digest) { +Status SendCreateAndSealRequest(const std::shared_ptr& client, + const ObjectID& object_id, const std::string& data, + const std::string& metadata, unsigned char* digest) { flatbuffers::FlatBufferBuilder fbb; auto digest_string = fbb.CreateString(reinterpret_cast(digest), kDigestSize); auto message = fb::CreatePlasmaCreateAndSealRequest( fbb, fbb.CreateString(object_id.binary()), fbb.CreateString(data), fbb.CreateString(metadata), digest_string); - return PlasmaSend(sock, MessageType::PlasmaCreateAndSealRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaCreateAndSealRequest, &fbb); } -Status ReadCreateAndSealRequest(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadCreateAndSealRequest(const uint8_t* data, size_t size, ObjectID* object_id, std::string* object_data, std::string* metadata, unsigned char* digest) { DCHECK(data); @@ -202,26 +212,30 @@ Status ReadCreateAndSealRequest(uint8_t* data, size_t size, ObjectID* object_id, return Status::OK(); } -Status SendCreateAndSealReply(int sock, PlasmaError error) { +Status SendCreateAndSealReply(const std::shared_ptr& client, + PlasmaError error) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaCreateAndSealReply(fbb, static_cast(error)); - return PlasmaSend(sock, MessageType::PlasmaCreateAndSealReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaCreateAndSealReply, &fbb); } -Status ReadCreateAndSealReply(uint8_t* data, size_t size) { +Status ReadCreateAndSealReply(const uint8_t* data, size_t size) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); return 
PlasmaErrorStatus(message->error()); } -Status SendAbortRequest(int sock, ObjectID object_id) { +Status SendAbortRequest(const std::shared_ptr& client, + ObjectID object_id) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaAbortRequest(fbb, fbb.CreateString(object_id.binary())); - return PlasmaSend(sock, MessageType::PlasmaAbortRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaAbortRequest, &fbb); } -Status ReadAbortRequest(uint8_t* data, size_t size, ObjectID* object_id) { +Status ReadAbortRequest(const uint8_t* data, size_t size, ObjectID* object_id) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -229,13 +243,15 @@ Status ReadAbortRequest(uint8_t* data, size_t size, ObjectID* object_id) { return Status::OK(); } -Status SendAbortReply(int sock, ObjectID object_id) { +Status SendAbortReply(const std::shared_ptr& client, + ObjectID object_id) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaAbortReply(fbb, fbb.CreateString(object_id.binary())); - return PlasmaSend(sock, MessageType::PlasmaAbortReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaAbortReply, &fbb); } -Status ReadAbortReply(uint8_t* data, size_t size, ObjectID* object_id) { +Status ReadAbortReply(const uint8_t* data, size_t size, ObjectID* object_id) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -245,15 +261,17 @@ Status ReadAbortReply(uint8_t* data, size_t size, ObjectID* object_id) { // Seal messages. 
-Status SendSealRequest(int sock, ObjectID object_id, unsigned char* digest) { +Status SendSealRequest(const std::shared_ptr& client, + ObjectID object_id, unsigned char* digest) { flatbuffers::FlatBufferBuilder fbb; auto digest_string = fbb.CreateString(reinterpret_cast(digest), kDigestSize); auto message = fb::CreatePlasmaSealRequest(fbb, fbb.CreateString(object_id.binary()), digest_string); - return PlasmaSend(sock, MessageType::PlasmaSealRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaSealRequest, &fbb); } -Status ReadSealRequest(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadSealRequest(const uint8_t* data, size_t size, ObjectID* object_id, unsigned char* digest) { DCHECK(data); auto message = flatbuffers::GetRoot(data); @@ -264,14 +282,16 @@ Status ReadSealRequest(uint8_t* data, size_t size, ObjectID* object_id, return Status::OK(); } -Status SendSealReply(int sock, ObjectID object_id, PlasmaError error) { +Status SendSealReply(const std::shared_ptr& client, ObjectID object_id, + PlasmaError error) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaSealReply(fbb, fbb.CreateString(object_id.binary()), error); - return PlasmaSend(sock, MessageType::PlasmaSealReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaSealReply, &fbb); } -Status ReadSealReply(uint8_t* data, size_t size, ObjectID* object_id) { +Status ReadSealReply(const uint8_t* data, size_t size, ObjectID* object_id) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -281,14 +301,16 @@ Status ReadSealReply(uint8_t* data, size_t size, ObjectID* object_id) { // Release messages. 
-Status SendReleaseRequest(int sock, ObjectID object_id) { +Status SendReleaseRequest(const std::shared_ptr& client, + ObjectID object_id) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaReleaseRequest(fbb, fbb.CreateString(object_id.binary())); - return PlasmaSend(sock, MessageType::PlasmaReleaseRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaReleaseRequest, &fbb); } -Status ReadReleaseRequest(uint8_t* data, size_t size, ObjectID* object_id) { +Status ReadReleaseRequest(const uint8_t* data, size_t size, ObjectID* object_id) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -296,14 +318,16 @@ Status ReadReleaseRequest(uint8_t* data, size_t size, ObjectID* object_id) { return Status::OK(); } -Status SendReleaseReply(int sock, ObjectID object_id, PlasmaError error) { +Status SendReleaseReply(const std::shared_ptr& client, + ObjectID object_id, PlasmaError error) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaReleaseReply(fbb, fbb.CreateString(object_id.binary()), error); - return PlasmaSend(sock, MessageType::PlasmaReleaseReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaReleaseReply, &fbb); } -Status ReadReleaseReply(uint8_t* data, size_t size, ObjectID* object_id) { +Status ReadReleaseReply(const uint8_t* data, size_t size, ObjectID* object_id) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -313,15 +337,18 @@ Status ReadReleaseReply(uint8_t* data, size_t size, ObjectID* object_id) { // Delete objects messages. 
-Status SendDeleteRequest(int sock, const std::vector& object_ids) { +Status SendDeleteRequest(const std::shared_ptr& client, + const std::vector& object_ids) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaDeleteRequest( fbb, static_cast(object_ids.size()), ToFlatbuffer(&fbb, &object_ids[0], object_ids.size())); - return PlasmaSend(sock, MessageType::PlasmaDeleteRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaDeleteRequest, &fbb); } -Status ReadDeleteRequest(uint8_t* data, size_t size, std::vector* object_ids) { +Status ReadDeleteRequest(const uint8_t* data, size_t size, + std::vector* object_ids) { using fb::PlasmaDeleteRequest; DCHECK(data); @@ -334,7 +361,8 @@ Status ReadDeleteRequest(uint8_t* data, size_t size, std::vector* obje return Status::OK(); } -Status SendDeleteReply(int sock, const std::vector& object_ids, +Status SendDeleteReply(const std::shared_ptr& client, + const std::vector& object_ids, const std::vector& errors) { DCHECK(object_ids.size() == errors.size()); flatbuffers::FlatBufferBuilder fbb; @@ -342,10 +370,12 @@ Status SendDeleteReply(int sock, const std::vector& object_ids, fbb, static_cast(object_ids.size()), ToFlatbuffer(&fbb, &object_ids[0], object_ids.size()), fbb.CreateVector(reinterpret_cast(&errors[0]), object_ids.size())); - return PlasmaSend(sock, MessageType::PlasmaDeleteReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaDeleteReply, &fbb); } -Status ReadDeleteReply(uint8_t* data, size_t size, std::vector* object_ids, +Status ReadDeleteReply(const uint8_t* data, size_t size, + std::vector* object_ids, std::vector* errors) { using fb::PlasmaDeleteReply; @@ -365,14 +395,16 @@ Status ReadDeleteReply(uint8_t* data, size_t size, std::vector* object // Contains messages. 
-Status SendContainsRequest(int sock, ObjectID object_id) { +Status SendContainsRequest(const std::shared_ptr& client, + ObjectID object_id) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaContainsRequest(fbb, fbb.CreateString(object_id.binary())); - return PlasmaSend(sock, MessageType::PlasmaContainsRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaContainsRequest, &fbb); } -Status ReadContainsRequest(uint8_t* data, size_t size, ObjectID* object_id) { +Status ReadContainsRequest(const uint8_t* data, size_t size, ObjectID* object_id) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -380,14 +412,16 @@ Status ReadContainsRequest(uint8_t* data, size_t size, ObjectID* object_id) { return Status::OK(); } -Status SendContainsReply(int sock, ObjectID object_id, bool has_object) { +Status SendContainsReply(const std::shared_ptr& client, + ObjectID object_id, bool has_object) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaContainsReply(fbb, fbb.CreateString(object_id.binary()), has_object); - return PlasmaSend(sock, MessageType::PlasmaContainsReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaContainsReply, &fbb); } -Status ReadContainsReply(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadContainsReply(const uint8_t* data, size_t size, ObjectID* object_id, bool* has_object) { DCHECK(data); auto message = flatbuffers::GetRoot(data); @@ -399,15 +433,17 @@ Status ReadContainsReply(uint8_t* data, size_t size, ObjectID* object_id, // List messages. 
-Status SendListRequest(int sock) { +Status SendListRequest(const std::shared_ptr& client) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaListRequest(fbb); - return PlasmaSend(sock, MessageType::PlasmaListRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaListRequest, &fbb); } -Status ReadListRequest(uint8_t* data, size_t size) { return Status::OK(); } +Status ReadListRequest(const uint8_t* data, size_t size) { return Status::OK(); } -Status SendListReply(int sock, const ObjectTable& objects) { +Status SendListReply(const std::shared_ptr& client, + const ObjectTable& objects) { flatbuffers::FlatBufferBuilder fbb; std::vector> object_infos; for (auto const& entry : objects) { @@ -422,10 +458,11 @@ Status SendListReply(int sock, const ObjectTable& objects) { object_infos.push_back(info); } auto message = fb::CreatePlasmaListReply(fbb, fbb.CreateVector(object_infos)); - return PlasmaSend(sock, MessageType::PlasmaListReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaListReply, &fbb); } -Status ReadListReply(uint8_t* data, size_t size, ObjectTable* objects) { +Status ReadListReply(const uint8_t* data, size_t size, ObjectTable* objects) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -446,21 +483,24 @@ Status ReadListReply(uint8_t* data, size_t size, ObjectTable* objects) { // Connect messages. 
-Status SendConnectRequest(int sock) { +Status SendConnectRequest(const std::shared_ptr& client) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaConnectRequest(fbb); - return PlasmaSend(sock, MessageType::PlasmaConnectRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaConnectRequest, &fbb); } -Status ReadConnectRequest(uint8_t* data) { return Status::OK(); } +Status ReadConnectRequest(const uint8_t* data) { return Status::OK(); } -Status SendConnectReply(int sock, int64_t memory_capacity) { +Status SendConnectReply(const std::shared_ptr& client, + int64_t memory_capacity) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaConnectReply(fbb, memory_capacity); - return PlasmaSend(sock, MessageType::PlasmaConnectReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaConnectReply, &fbb); } -Status ReadConnectReply(uint8_t* data, size_t size, int64_t* memory_capacity) { +Status ReadConnectReply(const uint8_t* data, size_t size, int64_t* memory_capacity) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -470,13 +510,15 @@ Status ReadConnectReply(uint8_t* data, size_t size, int64_t* memory_capacity) { // Evict messages. 
-Status SendEvictRequest(int sock, int64_t num_bytes) { +Status SendEvictRequest(const std::shared_ptr& client, + int64_t num_bytes) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaEvictRequest(fbb, num_bytes); - return PlasmaSend(sock, MessageType::PlasmaEvictRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaEvictRequest, &fbb); } -Status ReadEvictRequest(uint8_t* data, size_t size, int64_t* num_bytes) { +Status ReadEvictRequest(const uint8_t* data, size_t size, int64_t* num_bytes) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -484,13 +526,15 @@ Status ReadEvictRequest(uint8_t* data, size_t size, int64_t* num_bytes) { return Status::OK(); } -Status SendEvictReply(int sock, int64_t num_bytes) { +Status SendEvictReply(const std::shared_ptr& client, + int64_t num_bytes) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaEvictReply(fbb, num_bytes); - return PlasmaSend(sock, MessageType::PlasmaEvictReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaEvictReply, &fbb); } -Status ReadEvictReply(uint8_t* data, size_t size, int64_t& num_bytes) { +Status ReadEvictReply(const uint8_t* data, size_t size, int64_t& num_bytes) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -500,15 +544,17 @@ Status ReadEvictReply(uint8_t* data, size_t size, int64_t& num_bytes) { // Get messages. 
-Status SendGetRequest(int sock, const ObjectID* object_ids, int64_t num_objects, +Status SendGetRequest(const std::shared_ptr& client, + const ObjectID* object_ids, int64_t num_objects, int64_t timeout_ms) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaGetRequest( fbb, ToFlatbuffer(&fbb, object_ids, num_objects), timeout_ms); - return PlasmaSend(sock, MessageType::PlasmaGetRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaGetRequest, &fbb); } -Status ReadGetRequest(uint8_t* data, size_t size, std::vector& object_ids, +Status ReadGetRequest(const uint8_t* data, size_t size, std::vector& object_ids, int64_t* timeout_ms) { DCHECK(data); auto message = flatbuffers::GetRoot(data); @@ -521,7 +567,8 @@ Status ReadGetRequest(uint8_t* data, size_t size, std::vector& object_ return Status::OK(); } -Status SendGetReply(int sock, ObjectID object_ids[], +Status SendGetReply(const std::shared_ptr& client, + ObjectID object_ids[], std::unordered_map& plasma_objects, int64_t num_objects, const std::vector& store_fds, const std::vector& mmap_sizes) { @@ -547,10 +594,12 @@ Status SendGetReply(int sock, ObjectID object_ids[], fbb, ToFlatbuffer(&fbb, object_ids, num_objects), fbb.CreateVectorOfStructs(objects.data(), num_objects), fbb.CreateVector(store_fds), fbb.CreateVector(mmap_sizes), fbb.CreateVector(handles)); - return PlasmaSend(sock, MessageType::PlasmaGetReply, &fbb, message); + + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaGetReply, &fbb); } -Status ReadGetReply(uint8_t* data, size_t size, ObjectID object_ids[], +Status ReadGetReply(const uint8_t* data, size_t size, ObjectID object_ids[], PlasmaObject plasma_objects[], int64_t num_objects, std::vector& store_fds, std::vector& mmap_sizes) { DCHECK(data); @@ -587,26 +636,20 @@ Status ReadGetReply(uint8_t* data, size_t size, ObjectID object_ids[], return Status::OK(); } -// Subscribe messages. 
- -Status SendSubscribeRequest(int sock) { - flatbuffers::FlatBufferBuilder fbb; - auto message = fb::CreatePlasmaSubscribeRequest(fbb); - return PlasmaSend(sock, MessageType::PlasmaSubscribeRequest, &fbb, message); -} - // Data messages. -Status SendDataRequest(int sock, ObjectID object_id, const char* address, int port) { +Status SendDataRequest(const std::shared_ptr& client, + ObjectID object_id, const char* address, int port) { flatbuffers::FlatBufferBuilder fbb; auto addr = fbb.CreateString(address, strlen(address)); auto message = fb::CreatePlasmaDataRequest(fbb, fbb.CreateString(object_id.binary()), addr, port); - return PlasmaSend(sock, MessageType::PlasmaDataRequest, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaDataRequest, &fbb); } -Status ReadDataRequest(uint8_t* data, size_t size, ObjectID* object_id, char** address, - int* port) { +Status ReadDataRequest(const uint8_t* data, size_t size, ObjectID* object_id, + char** address, int* port) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(VerifyFlatbuffer(message, data, size)); @@ -617,15 +660,16 @@ Status ReadDataRequest(uint8_t* data, size_t size, ObjectID* object_id, char** a return Status::OK(); } -Status SendDataReply(int sock, ObjectID object_id, int64_t object_size, - int64_t metadata_size) { +Status SendDataReply(const std::shared_ptr& client, ObjectID object_id, + int64_t object_size, int64_t metadata_size) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaDataReply(fbb, fbb.CreateString(object_id.binary()), object_size, metadata_size); - return PlasmaSend(sock, MessageType::PlasmaDataReply, &fbb, message); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaDataReply, &fbb); } -Status ReadDataReply(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadDataReply(const uint8_t* data, size_t size, ObjectID* object_id, int64_t* object_size, int64_t* metadata_size) { DCHECK(data); auto message = 
flatbuffers::GetRoot(data); @@ -636,4 +680,39 @@ Status ReadDataReply(uint8_t* data, size_t size, ObjectID* object_id, return Status::OK(); } +Status SendSubscribeRequest(const std::shared_ptr& client) { + // Subscribe messages. + flatbuffers::FlatBufferBuilder fbb; + auto message = fb::CreatePlasmaSubscribeRequest(fbb); + fbb.Finish(message); + return PlasmaSend(client, MessageType::PlasmaSubscribeRequest, &fbb); +} + +void SerializeObjectDeletionNotification(const ObjectID& object_id, + std::vector* serialized) { + flatbuf::ObjectInfoT info; + info.object_id = object_id.binary(); + info.is_deletion = true; + flatbuffers::FlatBufferBuilder fbb; + auto message = fb::CreateObjectInfo(fbb, &info); + fbb.Finish(message); + serialized->resize(fbb.GetSize()); + serialized->assign(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); +} + +void SerializeObjectSealedNotification(const ObjectID& object_id, + const ObjectTableEntry& entry, + std::vector* serialized) { + flatbuf::ObjectInfoT info; + info.object_id = object_id.binary(); + info.data_size = entry.data_size; + info.metadata_size = entry.metadata_size; + info.digest = std::string(reinterpret_cast(&entry.digest[0]), kDigestSize); + flatbuffers::FlatBufferBuilder fbb; + auto message = fb::CreateObjectInfo(fbb, &info); + fbb.Finish(message); + serialized->resize(fbb.GetSize()); + serialized->assign(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); +} + } // namespace plasma diff --git a/cpp/src/plasma/protocol.h b/cpp/src/plasma/protocol.h index 0362bd47797..e9abb1935a0 100644 --- a/cpp/src/plasma/protocol.h +++ b/cpp/src/plasma/protocol.h @@ -24,168 +24,192 @@ #include #include "arrow/status.h" +#include "plasma/io/connection.h" #include "plasma/plasma.h" #include "plasma/plasma_generated.h" namespace plasma { using arrow::Status; - -using flatbuf::MessageType; using flatbuf::PlasmaError; +using io::ClientConnection; +using io::ServerConnection; template -bool VerifyFlatbuffer(T* object, 
uint8_t* data, size_t size) { +bool VerifyFlatbuffer(T* object, const uint8_t* data, size_t size) { flatbuffers::Verifier verifier(data, size); return object->Verify(verifier); } -/* Plasma receive message. */ - -Status PlasmaReceive(int sock, MessageType message_type, std::vector* buffer); - /* Plasma Create message functions. */ -Status SendCreateRequest(int sock, ObjectID object_id, int64_t data_size, - int64_t metadata_size, int device_num); +Status SendCreateRequest(const std::shared_ptr& client, + ObjectID object_id, int64_t data_size, int64_t metadata_size, + int device_num); -Status ReadCreateRequest(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadCreateRequest(const uint8_t* data, size_t size, ObjectID* object_id, int64_t* data_size, int64_t* metadata_size, int* device_num); -Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object, - PlasmaError error, int64_t mmap_size); +Status SendCreateReply(const std::shared_ptr& client, + ObjectID object_id, PlasmaObject* object, PlasmaError error, + int64_t mmap_size); -Status ReadCreateReply(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadCreateReply(const uint8_t* data, size_t size, ObjectID* object_id, PlasmaObject* object, int* store_fd, int64_t* mmap_size); -Status SendCreateAndSealRequest(int sock, const ObjectID& object_id, - const std::string& data, const std::string& metadata, - unsigned char* digest); +Status SendCreateAndSealRequest(const std::shared_ptr& client, + const ObjectID& object_id, const std::string& data, + const std::string& metadata, unsigned char* digest); -Status ReadCreateAndSealRequest(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadCreateAndSealRequest(const uint8_t* data, size_t size, ObjectID* object_id, std::string* object_data, std::string* metadata, unsigned char* digest); -Status SendCreateAndSealReply(int sock, PlasmaError error); +Status SendCreateAndSealReply(const std::shared_ptr& client, + PlasmaError error); -Status 
ReadCreateAndSealReply(uint8_t* data, size_t size); +Status ReadCreateAndSealReply(const uint8_t* data, size_t size); -Status SendAbortRequest(int sock, ObjectID object_id); +Status SendAbortRequest(const std::shared_ptr& client, + ObjectID object_id); -Status ReadAbortRequest(uint8_t* data, size_t size, ObjectID* object_id); +Status ReadAbortRequest(const uint8_t* data, size_t size, ObjectID* object_id); -Status SendAbortReply(int sock, ObjectID object_id); +Status SendAbortReply(const std::shared_ptr& client, + ObjectID object_id); -Status ReadAbortReply(uint8_t* data, size_t size, ObjectID* object_id); +Status ReadAbortReply(const uint8_t* data, size_t size, ObjectID* object_id); /* Plasma Seal message functions. */ -Status SendSealRequest(int sock, ObjectID object_id, unsigned char* digest); +Status SendSealRequest(const std::shared_ptr& client, + ObjectID object_id, unsigned char* digest); -Status ReadSealRequest(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadSealRequest(const uint8_t* data, size_t size, ObjectID* object_id, unsigned char* digest); -Status SendSealReply(int sock, ObjectID object_id, PlasmaError error); +Status SendSealReply(const std::shared_ptr& client, ObjectID object_id, + PlasmaError error); -Status ReadSealReply(uint8_t* data, size_t size, ObjectID* object_id); +Status ReadSealReply(const uint8_t* data, size_t size, ObjectID* object_id); /* Plasma Get message functions. 
*/ -Status SendGetRequest(int sock, const ObjectID* object_ids, int64_t num_objects, +Status SendGetRequest(const std::shared_ptr& client, + const ObjectID* object_ids, int64_t num_objects, int64_t timeout_ms); -Status ReadGetRequest(uint8_t* data, size_t size, std::vector& object_ids, +Status ReadGetRequest(const uint8_t* data, size_t size, std::vector& object_ids, int64_t* timeout_ms); -Status SendGetReply(int sock, ObjectID object_ids[], +Status SendGetReply(const std::shared_ptr& client, + ObjectID object_ids[], std::unordered_map& plasma_objects, int64_t num_objects, const std::vector& store_fds, const std::vector& mmap_sizes); -Status ReadGetReply(uint8_t* data, size_t size, ObjectID object_ids[], +Status ReadGetReply(const uint8_t* data, size_t size, ObjectID object_ids[], PlasmaObject plasma_objects[], int64_t num_objects, std::vector& store_fds, std::vector& mmap_sizes); /* Plasma Release message functions. */ -Status SendReleaseRequest(int sock, ObjectID object_id); +Status SendReleaseRequest(const std::shared_ptr& client, + ObjectID object_id); -Status ReadReleaseRequest(uint8_t* data, size_t size, ObjectID* object_id); +Status ReadReleaseRequest(const uint8_t* data, size_t size, ObjectID* object_id); -Status SendReleaseReply(int sock, ObjectID object_id, PlasmaError error); +Status SendReleaseReply(const std::shared_ptr& client, + ObjectID object_id, PlasmaError error); -Status ReadReleaseReply(uint8_t* data, size_t size, ObjectID* object_id); +Status ReadReleaseReply(const uint8_t* data, size_t size, ObjectID* object_id); /* Plasma Delete objects message functions. 
*/ -Status SendDeleteRequest(int sock, const std::vector& object_ids); +Status SendDeleteRequest(const std::shared_ptr& client, + const std::vector& object_ids); -Status ReadDeleteRequest(uint8_t* data, size_t size, std::vector* object_ids); +Status ReadDeleteRequest(const uint8_t* data, size_t size, + std::vector* object_ids); -Status SendDeleteReply(int sock, const std::vector& object_ids, +Status SendDeleteReply(const std::shared_ptr& client, + const std::vector& object_ids, const std::vector& errors); -Status ReadDeleteReply(uint8_t* data, size_t size, std::vector* object_ids, +Status ReadDeleteReply(const uint8_t* data, size_t size, + std::vector* object_ids, std::vector* errors); /* Plasma Constains message functions. */ -Status SendContainsRequest(int sock, ObjectID object_id); +Status SendContainsRequest(const std::shared_ptr& client, + ObjectID object_id); -Status ReadContainsRequest(uint8_t* data, size_t size, ObjectID* object_id); +Status ReadContainsRequest(const uint8_t* data, size_t size, ObjectID* object_id); -Status SendContainsReply(int sock, ObjectID object_id, bool has_object); +Status SendContainsReply(const std::shared_ptr& client, + ObjectID object_id, bool has_object); -Status ReadContainsReply(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadContainsReply(const uint8_t* data, size_t size, ObjectID* object_id, bool* has_object); /* Plasma List message functions. */ -Status SendListRequest(int sock); +Status SendListRequest(const std::shared_ptr& client); -Status ReadListRequest(uint8_t* data, size_t size); +Status ReadListRequest(const uint8_t* data, size_t size); -Status SendListReply(int sock, const ObjectTable& objects); +Status SendListReply(const std::shared_ptr& client, + const ObjectTable& objects); -Status ReadListReply(uint8_t* data, size_t size, ObjectTable* objects); +Status ReadListReply(const uint8_t* data, size_t size, ObjectTable* objects); /* Plasma Connect message functions. 
*/ -Status SendConnectRequest(int sock); +Status SendConnectRequest(const std::shared_ptr& client); -Status ReadConnectRequest(uint8_t* data, size_t size); +Status ReadConnectRequest(const uint8_t* data, size_t size); -Status SendConnectReply(int sock, int64_t memory_capacity); +Status SendConnectReply(const std::shared_ptr& client, + int64_t memory_capacity); -Status ReadConnectReply(uint8_t* data, size_t size, int64_t* memory_capacity); +Status ReadConnectReply(const uint8_t* data, size_t size, int64_t* memory_capacity); /* Plasma Evict message functions (no reply so far). */ -Status SendEvictRequest(int sock, int64_t num_bytes); - -Status ReadEvictRequest(uint8_t* data, size_t size, int64_t* num_bytes); +Status SendEvictRequest(const std::shared_ptr& client, + int64_t num_bytes); -Status SendEvictReply(int sock, int64_t num_bytes); +Status ReadEvictRequest(const uint8_t* data, size_t size, int64_t* num_bytes); -Status ReadEvictReply(uint8_t* data, size_t size, int64_t& num_bytes); +Status SendEvictReply(const std::shared_ptr& client, int64_t num_bytes); -/* Plasma Subscribe message functions. */ - -Status SendSubscribeRequest(int sock); +Status ReadEvictReply(const uint8_t* data, size_t size, int64_t& num_bytes); /* Data messages. 
*/ -Status SendDataRequest(int sock, ObjectID object_id, const char* address, int port); +Status SendDataRequest(const std::shared_ptr& client, + ObjectID object_id, const char* address, int port); -Status ReadDataRequest(uint8_t* data, size_t size, ObjectID* object_id, char** address, - int* port); +Status ReadDataRequest(const uint8_t* data, size_t size, ObjectID* object_id, + char** address, int* port); -Status SendDataReply(int sock, ObjectID object_id, int64_t object_size, - int64_t metadata_size); +Status SendDataReply(const std::shared_ptr& client, ObjectID object_id, + int64_t object_size, int64_t metadata_size); -Status ReadDataReply(uint8_t* data, size_t size, ObjectID* object_id, +Status ReadDataReply(const uint8_t* data, size_t size, ObjectID* object_id, int64_t* object_size, int64_t* metadata_size); +/* Plasma notification message functions. */ + +Status SendSubscribeRequest(const std::shared_ptr& client); + +void SerializeObjectDeletionNotification(const ObjectID& object_id, + std::vector* serialized); + +void SerializeObjectSealedNotification(const ObjectID& object_id, + const ObjectTableEntry& entry, + std::vector* serialized); } // namespace plasma #endif /* PLASMA_PROTOCOL */ diff --git a/cpp/src/plasma/store.cc b/cpp/src/plasma/store.cc index 4d866530b48..b83697c5ff5 100644 --- a/cpp/src/plasma/store.cc +++ b/cpp/src/plasma/store.cc @@ -28,23 +28,16 @@ #include "plasma/store.h" -#include -#include #include -#include #include #include #include -#include -#include -#include +#include +#ifdef __linux__ #include -#include -#include -#include - +#endif +#include #include -#include #include #include #include @@ -53,13 +46,11 @@ #include #include "arrow/status.h" - +#include "arrow/util/logging.h" #include "plasma/common.h" -#include "plasma/common_generated.h" -#include "plasma/fling.h" -#include "plasma/io.h" #include "plasma/malloc.h" #include "plasma/plasma_allocator.h" +#include "plasma/plasma_generated.h" #include "plasma/protocol.h" #ifdef 
PLASMA_CUDA @@ -73,19 +64,72 @@ using arrow::cuda::CudaDeviceManager; using arrow::util::ArrowLog; using arrow::util::ArrowLogLevel; -namespace fb = plasma::flatbuf; - namespace plasma { +using flatbuf::MessageType; + void SetMallocGranularity(int value); struct GetRequest { - GetRequest(Client* client, const std::vector& object_ids); + GetRequest(asio::io_context& io_context, + const std::shared_ptr& client, + const std::vector& object_ids) + : client(client), + object_ids(object_ids.begin(), object_ids.end()), + objects(object_ids.size()), + num_satisfied(0), + timer_(io_context) { + std::unordered_set unique_ids(object_ids.begin(), object_ids.end()); + num_objects_to_wait_for = unique_ids.size(); + } + + void ReturnFromGet() { + // Figure out how many file descriptors we need to send. + std::unordered_set fds_to_send; + std::vector store_fds; + std::vector mmap_sizes; + for (const auto& object_id : object_ids) { + PlasmaObject& object = objects[object_id]; + int fd = object.store_fd; + if (object.data_size != -1 && fds_to_send.count(fd) == 0 && fd != -1) { + fds_to_send.insert(fd); + store_fds.push_back(fd); + mmap_sizes.push_back(GetMmapSize(fd)); + } + } + + // Send the get reply to the client. + Status s = SendGetReply(client, &object_ids[0], objects, object_ids.size(), store_fds, + mmap_sizes); + // If we successfully sent the get reply message to the client, then also send + // the file descriptors. + if (s.ok()) { + // Send all of the file descriptors for the present objects. + for (int store_fd : store_fds) { + // Only send the file descriptor if it hasn't been sent (see analogous + // logic in GetStoreFd in client.cc). 
+ if (client->used_fds.find(store_fd) == client->used_fds.end()) { + auto status = client->SendFd(store_fd); + if (!status.ok()) { + ARROW_LOG(ERROR) << "Failed to send a mmap fd to client"; + } + client->used_fds.insert(store_fd); + } + } + } + } + + void AsyncWait(int64_t timeout_ms, + std::function on_timeout) { + // Set an expiry time relative to now. + timer_.expires_from_now(std::chrono::milliseconds(timeout_ms)); + timer_.async_wait(on_timeout); + } + + void CancelTimer() { timer_.cancel(); } + /// The client that called get. - Client* client; - /// The ID of the timer that will time out and cause this wait to return to - /// the client if it hasn't already returned. - int64_t timer; + std::shared_ptr client; /// The object IDs involved in this request. This is used in the reply. std::vector object_ids; /// The object information for the objects in this request. This is used in @@ -96,29 +140,29 @@ struct GetRequest { /// The number of object requests in this wait request that are already /// satisfied. int64_t num_satisfied; -}; - -GetRequest::GetRequest(Client* client, const std::vector& object_ids) - : client(client), - timer(-1), - object_ids(object_ids.begin(), object_ids.end()), - objects(object_ids.size()), - num_satisfied(0) { - std::unordered_set unique_ids(object_ids.begin(), object_ids.end()); - num_objects_to_wait_for = unique_ids.size(); -} -Client::Client(int fd) : fd(fd), notification_fd(-1) {} + private: + /// The timer that will time out and cause this wait to return to + /// the client if it hasn't already returned. 
+ asio::steady_timer timer_; +}; -PlasmaStore::PlasmaStore(EventLoop* loop, std::string directory, bool hugepages_enabled, - const std::string& socket_name, +PlasmaStore::PlasmaStore(asio::io_context& io_context, std::string directory, + bool hugepages_enabled, const std::string& stream_name, std::shared_ptr external_store) - : loop_(loop), eviction_policy_(&store_info_), external_store_(external_store) { + : eviction_policy_(&store_info_), + external_store_(external_store), + io_context_(io_context), + stream_name_(stream_name), + acceptor_(io::CreateLocalAcceptor(io_context, stream_name)), + stream_(io_context) { store_info_.directory = directory; store_info_.hugepages_enabled = hugepages_enabled; #ifdef PLASMA_CUDA DCHECK_OK(CudaDeviceManager::GetInstance(&manager_)); #endif + // Start listening for clients. + DoAccept(); } // TODO(pcm): Get rid of this destructor by using RAII to clean up data. @@ -129,22 +173,12 @@ const PlasmaStoreInfo* PlasmaStore::GetPlasmaStoreInfo() { return &store_info_; // If this client is not already using the object, add the client to the // object's list of clients, otherwise do nothing. void PlasmaStore::AddToClientObjectIds(const ObjectID& object_id, ObjectTableEntry* entry, - Client* client) { + const std::shared_ptr& client) { // Check if this client is already using the object. - if (client->object_ids.find(object_id) != client->object_ids.end()) { + if (client->ObjectIDExists(object_id)) { return; } - // If there are no other clients using this object, notify the eviction policy - // that the object is being used. - if (entry->ref_count == 0) { - // Tell the eviction policy that this object is being used. - std::vector objects_to_evict; - eviction_policy_.BeginObjectAccess(object_id, &objects_to_evict); - EvictObjects(objects_to_evict); - } - // Increase reference count. - entry->ref_count++; - + IncreaseObjectRefCount(object_id, entry); // Add object id to the list of object ids that this client is using. 
client->object_ids.insert(object_id); } @@ -199,7 +233,8 @@ Status PlasmaStore::AllocateCudaMemory( // Create a new object buffer in the hash table. PlasmaError PlasmaStore::CreateObject(const ObjectID& object_id, int64_t data_size, int64_t metadata_size, int device_num, - Client* client, PlasmaObject* result) { + const std::shared_ptr& client, + PlasmaObject* result) { ARROW_LOG(DEBUG) << "creating object " << object_id.hex(); auto entry = GetObjectTableEntry(&store_info_, object_id); @@ -304,13 +339,12 @@ void PlasmaStore::RemoveGetRequest(GetRequest* get_request) { } } // Remove the get request. - if (get_request->timer != -1) { - ARROW_CHECK(loop_->RemoveTimer(get_request->timer) == kEventLoopOk); - } + get_request->CancelTimer(); delete get_request; } -void PlasmaStore::RemoveGetRequestsForClient(Client* client) { +void PlasmaStore::RemoveGetRequestsForClient( + const std::shared_ptr& client) { std::unordered_set get_requests_to_remove; for (auto const& pair : object_get_requests_) { for (GetRequest* get_request : pair.second) { @@ -329,38 +363,7 @@ void PlasmaStore::RemoveGetRequestsForClient(Client* client) { } void PlasmaStore::ReturnFromGet(GetRequest* get_req) { - // Figure out how many file descriptors we need to send. - std::unordered_set fds_to_send; - std::vector store_fds; - std::vector mmap_sizes; - for (const auto& object_id : get_req->object_ids) { - PlasmaObject& object = get_req->objects[object_id]; - int fd = object.store_fd; - if (object.data_size != -1 && fds_to_send.count(fd) == 0 && fd != -1) { - fds_to_send.insert(fd); - store_fds.push_back(fd); - mmap_sizes.push_back(GetMmapSize(fd)); - } - } - - // Send the get reply to the client. - Status s = SendGetReply(get_req->client->fd, &get_req->object_ids[0], get_req->objects, - get_req->object_ids.size(), store_fds, mmap_sizes); - WarnIfSigpipe(s.ok() ? 0 : -1, get_req->client->fd); - // If we successfully sent the get reply message to the client, then also send - // the file descriptors. 
- if (s.ok()) { - // Send all of the file descriptors for the present objects. - for (int store_fd : store_fds) { - // Only send the file descriptor if it hasn't been sent (see analogous - // logic in GetStoreFd in client.cc). - if (get_req->client->used_fds.find(store_fd) == get_req->client->used_fds.end()) { - WarnIfSigpipe(send_fd(get_req->client->fd, store_fd), get_req->client->fd); - get_req->client->used_fds.insert(store_fd); - } - } - } - + get_req->ReturnFromGet(); // Remove the get request from each of the relevant object_get_requests hash // tables if it is present there. It should only be present there if the get // request timed out. @@ -411,11 +414,11 @@ void PlasmaStore::UpdateObjectGetRequests(const ObjectID& object_id) { } } -void PlasmaStore::ProcessGetRequest(Client* client, - const std::vector& object_ids, - int64_t timeout_ms) { +Status PlasmaStore::ProcessGetRequest(const std::shared_ptr& client, + const std::vector& object_ids, + int64_t timeout_ms) { // Create a get request for this object. - auto get_req = new GetRequest(client, object_ids); + auto get_req = new GetRequest(io_context_, client, object_ids); std::vector evicted_ids; std::vector evicted_entries; for (auto object_id : object_ids) { @@ -492,42 +495,14 @@ void PlasmaStore::ProcessGetRequest(Client* client, } else if (timeout_ms != -1) { // Set a timer that will cause the get request to return to the client. Note // that a timeout of -1 is used to indicate that no timer should be set. - get_req->timer = loop_->AddTimer(timeout_ms, [this, get_req](int64_t timer_id) { - ReturnFromGet(get_req); - return kEventLoopTimerDone; - }); - } -} - -int PlasmaStore::RemoveFromClientObjectIds(const ObjectID& object_id, - ObjectTableEntry* entry, Client* client) { - auto it = client->object_ids.find(object_id); - if (it != client->object_ids.end()) { - client->object_ids.erase(it); - // Decrease reference count. 
- entry->ref_count--; - - // If no more clients are using this object, notify the eviction policy - // that the object is no longer being used. - if (entry->ref_count == 0) { - if (deletion_cache_.count(object_id) == 0) { - // Tell the eviction policy that this object is no longer being used. - std::vector objects_to_evict; - eviction_policy_.EndObjectAccess(object_id, &objects_to_evict); - EvictObjects(objects_to_evict); - } else { - // Above code does not really delete an object. Instead, it just put an - // object to LRU cache which will be cleaned when the memory is not enough. - deletion_cache_.erase(object_id); - EvictObjects({object_id}); + get_req->AsyncWait(timeout_ms, [this, get_req](const asio::error_code& ec) { + if (ec != asio::error::operation_aborted) { + // Timer was not cancelled, take necessary action. + ReturnFromGet(get_req); } - } - // Return 1 to indicate that the client was removed. - return 1; - } else { - // Return 0 to indicate that the client was not removed. - return 0; + }); } + return Status::OK(); } void PlasmaStore::EraseFromObjectTable(const ObjectID& object_id) { @@ -536,11 +511,13 @@ void PlasmaStore::EraseFromObjectTable(const ObjectID& object_id) { store_info_.objects.erase(object_id); } -void PlasmaStore::ReleaseObject(const ObjectID& object_id, Client* client) { +void PlasmaStore::ReleaseObject(const ObjectID& object_id, + const std::shared_ptr& client) { + // Remove the client from the object's array of clients. + ARROW_CHECK(client->RemoveObjectIDIfExists(object_id)); auto entry = GetObjectTableEntry(&store_info_, object_id); ARROW_CHECK(entry != nullptr); - // Remove the client from the object's array of clients. - ARROW_CHECK(RemoveFromClientObjectIds(object_id, entry, client) == 1); + DecreaseObjectRefCount(object_id, entry); } // Check if an object is present. 
@@ -566,31 +543,23 @@ void PlasmaStore::SealObject(const ObjectID& object_id, unsigned char digest[]) entry->construct_duration = std::time(nullptr) - entry->create_time; // Inform all subscribers that a new object has been sealed. - ObjectInfoT info; - info.object_id = object_id.binary(); - info.data_size = entry->data_size; - info.metadata_size = entry->metadata_size; - info.digest = std::string(reinterpret_cast(&digest[0]), kDigestSize); - PushNotification(&info); - + PushObjectReadyNotification(object_id, *entry); // Update all get requests that involve this object. UpdateObjectGetRequests(object_id); } -int PlasmaStore::AbortObject(const ObjectID& object_id, Client* client) { +int PlasmaStore::AbortObject(const ObjectID& object_id, + const std::shared_ptr& client) { auto entry = GetObjectTableEntry(&store_info_, object_id); ARROW_CHECK(entry != nullptr) << "To abort an object it must be in the object table."; ARROW_CHECK(entry->state != ObjectState::PLASMA_SEALED) << "To abort an object it must not have been sealed."; - auto it = client->object_ids.find(object_id); - if (it == client->object_ids.end()) { - // If the client requesting the abort is not the creator, do not - // perform the abort. - return 0; - } else { + if (client->ObjectIDExists(object_id)) { // The client requesting the abort is the creator. Free the object. EraseFromObjectTable(object_id); return 1; + } else { + return 0; } } @@ -621,11 +590,7 @@ PlasmaError PlasmaStore::DeleteObject(ObjectID& object_id) { eviction_policy_.RemoveObject(object_id); EraseFromObjectTable(object_id); // Inform all subscribers that the object has been deleted. - fb::ObjectInfoT notification; - notification.object_id = object_id.binary(); - notification.is_deletion = true; - PushNotification(&notification); - + PushObjectDeletionNotification(object_id); return PlasmaError::OK; } @@ -656,10 +621,7 @@ void PlasmaStore::EvictObjects(const std::vector& object_ids) { // and send a deletion notification.
EraseFromObjectTable(object_id); // Inform all subscribers that the object has been deleted. - fb::ObjectInfoT notification; - notification.object_id = object_id.binary(); - notification.is_deletion = true; - PushNotification(&notification); + PushObjectDeletionNotification(object_id); } } @@ -673,33 +635,117 @@ void PlasmaStore::EvictObjects(const std::vector& object_ids) { } } -void PlasmaStore::ConnectClient(int listener_sock) { - int client_fd = AcceptClient(listener_sock); +void PlasmaStore::IncreaseObjectRefCount(const ObjectID& object_id, + ObjectTableEntry* entry) { + // If there are no other clients using this object, notify the eviction policy + // that the object is being used. + if (entry->ref_count == 0) { + // Tell the eviction policy that this object is being used. + std::vector objects_to_evict; + eviction_policy_.BeginObjectAccess(object_id, &objects_to_evict); + EvictObjects(objects_to_evict); + } + // Increase reference count. + entry->ref_count++; +} - Client* client = new Client(client_fd); - connected_clients_[client_fd] = std::unique_ptr(client); +void PlasmaStore::DecreaseObjectRefCount(const ObjectID& object_id, + ObjectTableEntry* entry) { + // Decrease reference count. + entry->ref_count--; - // Add a callback to handle events on this socket. - // TODO(pcm): Check return value. - loop_->AddFileEvent(client_fd, kEventLoopRead, [this, client](int events) { - Status s = ProcessMessage(client); - if (!s.ok()) { - ARROW_LOG(FATAL) << "Failed to process file event: " << s; + // If no more clients are using this object, notify the eviction policy + // that the object is no longer being used. + if (entry->ref_count == 0) { + if (deletion_cache_.count(object_id) == 0) { + // Tell the eviction policy that this object is no longer being used. + std::vector objects_to_evict; + eviction_policy_.EndObjectAccess(object_id, &objects_to_evict); + EvictObjects(objects_to_evict); + } else { + // Above code does not really delete an object.
Instead, it just put an + // object to LRU cache which will be cleaned when the memory is not enough. + deletion_cache_.erase(object_id); + EvictObjects({object_id}); } - }); - ARROW_LOG(DEBUG) << "New connection with fd " << client_fd; + } +} + +void PlasmaStore::PushObjectReadyNotification(const ObjectID& object_id, + const ObjectTableEntry& entry) { + for (const auto& client : notification_clients_) { + client->SendObjectReadyAsync(object_id, entry); + } } -void PlasmaStore::DisconnectClient(int client_fd) { - ARROW_CHECK(client_fd > 0); - auto it = connected_clients_.find(client_fd); +void PlasmaStore::PushObjectDeletionNotification(const ObjectID& object_id) { + for (const auto& client : notification_clients_) { + client->SendObjectDeletionAsync(object_id); + } +} + +// Subscribe to notifications about sealed objects. +void PlasmaStore::SubscribeToUpdates(const std::shared_ptr& client) { + ARROW_LOG(DEBUG) << "subscribing to updates on fd " << client->GetNativeHandle(); + if (notification_clients_.count(client) > 0) { + // This client has already subscribed. Return. + return; + } + + // Add this client to the notification set, which is needed for this client to receive + // notifications. + notification_clients_.insert(client); + + // Push notifications to the new subscriber about existing sealed objects. + for (const auto& entry : store_info_.objects) { + if (entry.second->state == ObjectState::PLASMA_SEALED) { + client->SendObjectReadyAsync(entry.first, *entry.second); + } + } +} + +void PlasmaStore::DoAccept() { + // TODO(suquark): Use shared_from_this() here ? 
+ acceptor_.async_accept(stream_, + [this](const asio::error_code& ec) { HandleAccept(ec); }); +} + +void PlasmaStore::HandleAccept(const asio::error_code& error) { + if (!error) { + io::MessageHandler message_handler = [this](std::shared_ptr client, + int64_t message_type, int64_t length, + const uint8_t* message) { + Status s = ProcessClientMessage(client, message_type, length, message); + if (!s.ok()) { + ARROW_LOG(FATAL) << "[Plasma Store] Failed to process the event" + << "(type=" << message_type << "): " << s; + } + }; + // Accept a new local client and dispatch it to the store. + auto new_connection = ClientConnection::Create(std::move(stream_), message_handler, + "plasma_store_client"); + // Insert the client before processing messages. + connected_clients_.insert(new_connection); + // Process our new connection. + new_connection->ProcessMessages(); + } + // We're ready to accept another client. + DoAccept(); +} + +void PlasmaStore::ProcessDisconnectClient( + const std::shared_ptr& client) { + ARROW_CHECK(client->IsOpen()); + auto it = connected_clients_.find(client); ARROW_CHECK(it != connected_clients_.end()); - loop_->RemoveFileEvent(client_fd); - // Close the socket. - close(client_fd); - ARROW_LOG(INFO) << "Disconnecting client on fd " << client_fd; + // Remove the client from the notification list. + if (notification_clients_.count(client) > 0) { + notification_clients_.erase(client); + } + // Close the client. + ARROW_LOG(INFO) << "Disconnecting client on fd " << client->GetNativeHandle(); + client->Close(); // Release all the objects that the client was using. 
- auto client = it->second.get(); std::unordered_map sealed_objects; for (const auto& object_id : client->object_ids) { auto it = store_info_.objects.find(object_id); @@ -721,193 +767,60 @@ void PlasmaStore::DisconnectClient(int client_fd) { RemoveGetRequestsForClient(client); for (const auto& entry : sealed_objects) { - RemoveFromClientObjectIds(entry.first, entry.second, client); + // The object ID must exist in client's record. + client->RemoveObjectID(entry.first); + DecreaseObjectRefCount(entry.first, entry.second); } - - if (client->notification_fd > 0) { - // This client has subscribed for notifications. - auto notify_fd = client->notification_fd; - loop_->RemoveFileEvent(notify_fd); - // Close socket. - close(notify_fd); - // Remove notification queue for this fd from global map. - pending_notifications_.erase(notify_fd); - // Reset fd. - client->notification_fd = -1; - } - connected_clients_.erase(it); } -/// Send notifications about sealed objects to the subscribers. This is called -/// in SealObject. If the socket's send buffer is full, the notification will -/// be buffered, and this will be called again when the send buffer has room. -/// Since we call erase on pending_notifications_, all iterators get -/// invalidated, which is why we return a valid iterator to the next client to -/// be used in PushNotification. -/// -/// @param it Iterator that points to the client to send the notification to. -/// @return Iterator pointing to the next client. -PlasmaStore::NotificationMap::iterator PlasmaStore::SendNotifications( - PlasmaStore::NotificationMap::iterator it) { - int client_fd = it->first; - auto& notifications = it->second.object_notifications; - - int num_processed = 0; - bool closed = false; - // Loop over the array of pending notifications and send as many of them as - // possible. - for (size_t i = 0; i < notifications.size(); ++i) { - auto& notification = notifications.at(i); - // Decode the length, which is the first bytes of the message. 
- int64_t size = *(reinterpret_cast(notification.get())); - - // Attempt to send a notification about this object ID. - ssize_t nbytes = send(client_fd, notification.get(), sizeof(int64_t) + size, 0); - if (nbytes >= 0) { - ARROW_CHECK(nbytes == static_cast(sizeof(int64_t)) + size); - } else if (nbytes == -1 && - (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)) { - ARROW_LOG(DEBUG) << "The socket's send buffer is full, so we are caching this " - "notification and will send it later."; - // Add a callback to the event loop to send queued notifications whenever - // there is room in the socket's send buffer. Callbacks can be added - // more than once here and will be overwritten. The callback is removed - // at the end of the method. - // TODO(pcm): Introduce status codes and check in case the file descriptor - // is added twice. - loop_->AddFileEvent(client_fd, kEventLoopWrite, [this, client_fd](int events) { - SendNotifications(pending_notifications_.find(client_fd)); - }); - break; - } else { - ARROW_LOG(WARNING) << "Failed to send notification to client on fd " << client_fd; - if (errno == EPIPE) { - closed = true; - break; - } - } - num_processed += 1; - } - // Remove the sent notifications from the array. - notifications.erase(notifications.begin(), notifications.begin() + num_processed); - - // If we have sent all notifications, remove the fd from the event loop. - if (notifications.empty()) { - loop_->RemoveFileEvent(client_fd); - } - - // Stop sending notifications if the pipe was broken. 
- if (closed) { - close(client_fd); - return pending_notifications_.erase(it); - } else { - return ++it; - } -} - -void PlasmaStore::PushNotification(fb::ObjectInfoT* object_info) { - auto it = pending_notifications_.begin(); - while (it != pending_notifications_.end()) { - auto notification = CreateObjectInfoBuffer(object_info); - it->second.object_notifications.emplace_back(std::move(notification)); - it = SendNotifications(it); - } -} - -void PlasmaStore::PushNotification(fb::ObjectInfoT* object_info, int client_fd) { - auto it = pending_notifications_.find(client_fd); - if (it != pending_notifications_.end()) { - auto notification = CreateObjectInfoBuffer(object_info); - it->second.object_notifications.emplace_back(std::move(notification)); - SendNotifications(it); - } -} - -// Subscribe to notifications about sealed objects. -void PlasmaStore::SubscribeToUpdates(Client* client) { - ARROW_LOG(DEBUG) << "subscribing to updates on fd " << client->fd; - if (client->notification_fd > 0) { - // This client has already subscribed. Return. - return; - } - - // TODO(rkn): The store could block here if the client doesn't send a file - // descriptor. - int fd = recv_fd(client->fd); - if (fd < 0) { - // This may mean that the client died before sending the file descriptor. - ARROW_LOG(WARNING) << "Failed to receive file descriptor from client on fd " - << client->fd << "."; - return; - } - - // Add this fd to global map, which is needed for this client to receive notifications. - pending_notifications_[fd]; - client->notification_fd = fd; - - // Push notifications to the new subscriber about existing sealed objects. 
- for (const auto& entry : store_info_.objects) { - if (entry.second->state == ObjectState::PLASMA_SEALED) { - ObjectInfoT info; - info.object_id = entry.first.binary(); - info.data_size = entry.second->data_size; - info.metadata_size = entry.second->metadata_size; - info.digest = - std::string(reinterpret_cast(&entry.second->digest[0]), kDigestSize); - PushNotification(&info, fd); - } - } -} - -Status PlasmaStore::ProcessMessage(Client* client) { - fb::MessageType type; - Status s = ReadMessage(client->fd, &type, &input_buffer_); - ARROW_CHECK(s.ok() || s.IsIOError()); - - uint8_t* input = input_buffer_.data(); - size_t input_size = input_buffer_.size(); +Status PlasmaStore::ProcessClientMessage(const std::shared_ptr& client, + int64_t message_type, int64_t message_size, + const uint8_t* message_data) { + auto message_type_value = static_cast(message_type); ObjectID object_id; - PlasmaObject object = {}; // Process the different types of requests. - switch (type) { - case fb::MessageType::PlasmaCreateRequest: { + switch (message_type_value) { + case MessageType::PlasmaCreateRequest: { int64_t data_size; int64_t metadata_size; int device_num; - RETURN_NOT_OK(ReadCreateRequest(input, input_size, &object_id, &data_size, + RETURN_NOT_OK(ReadCreateRequest(message_data, message_size, &object_id, &data_size, &metadata_size, &device_num)); + PlasmaObject object = {}; PlasmaError error_code = CreateObject(object_id, data_size, metadata_size, device_num, client, &object); int64_t mmap_size = 0; if (error_code == PlasmaError::OK && device_num == 0) { mmap_size = GetMmapSize(object.store_fd); } - HANDLE_SIGPIPE( - SendCreateReply(client->fd, object_id, &object, error_code, mmap_size), - client->fd); + RETURN_NOT_OK(SendCreateReply(client, object_id, &object, error_code, mmap_size)); // Only send the file descriptor if it hasn't been sent (see analogous // logic in GetStoreFd in client.cc). Similar in ReturnFromGet. 
if (error_code == PlasmaError::OK && device_num == 0 && client->used_fds.find(object.store_fd) == client->used_fds.end()) { - WarnIfSigpipe(send_fd(client->fd, object.store_fd), client->fd); + auto status = client->SendFd(object.store_fd); + if (!status.ok()) { + ARROW_LOG(ERROR) << "Failed to send a mmap fd to the client."; + } client->used_fds.insert(object.store_fd); } } break; - case fb::MessageType::PlasmaCreateAndSealRequest: { + case MessageType::PlasmaCreateAndSealRequest: { std::string data; std::string metadata; unsigned char digest[kDigestSize]; - RETURN_NOT_OK(ReadCreateAndSealRequest(input, input_size, &object_id, &data, - &metadata, &digest[0])); + RETURN_NOT_OK(ReadCreateAndSealRequest(message_data, message_size, &object_id, + &data, &metadata, &digest[0])); + PlasmaObject object = {}; // CreateAndSeal currently only supports device_num = 0, which corresponds // to the host. int device_num = 0; PlasmaError error_code = CreateObject(object_id, data.size(), metadata.size(), device_num, client, &object); // Reply to the client. - HANDLE_SIGPIPE(SendCreateAndSealReply(client->fd, error_code), client->fd); + RETURN_NOT_OK(SendCreateAndSealReply(client, error_code)); // If the object was successfully created, fill out the object data and seal it. if (error_code == PlasmaError::OK) { @@ -921,78 +834,77 @@ Status PlasmaStore::ProcessMessage(Client* client) { // object is not being used by any client. The client was added to the // object's array of clients in CreateObject. This is analogous to the // Release call that happens in the client's Seal method. 
- ARROW_CHECK(RemoveFromClientObjectIds(object_id, entry, client) == 1); + ARROW_CHECK(client->RemoveObjectIDIfExists(object_id)); + DecreaseObjectRefCount(object_id, entry); } } break; - case fb::MessageType::PlasmaAbortRequest: { - RETURN_NOT_OK(ReadAbortRequest(input, input_size, &object_id)); + case MessageType::PlasmaAbortRequest: { + RETURN_NOT_OK(ReadAbortRequest(message_data, message_size, &object_id)); ARROW_CHECK(AbortObject(object_id, client) == 1) << "To abort an object, the only " "client currently using it " "must be the creator."; - HANDLE_SIGPIPE(SendAbortReply(client->fd, object_id), client->fd); + RETURN_NOT_OK(SendAbortReply(client, object_id)); } break; - case fb::MessageType::PlasmaGetRequest: { - std::vector object_ids_to_get; + case MessageType::PlasmaGetRequest: { + std::vector object_ids; int64_t timeout_ms; - RETURN_NOT_OK(ReadGetRequest(input, input_size, object_ids_to_get, &timeout_ms)); - ProcessGetRequest(client, object_ids_to_get, timeout_ms); + RETURN_NOT_OK(ReadGetRequest(message_data, message_size, object_ids, &timeout_ms)); + RETURN_NOT_OK(ProcessGetRequest(client, object_ids, timeout_ms)); } break; - case fb::MessageType::PlasmaReleaseRequest: { - RETURN_NOT_OK(ReadReleaseRequest(input, input_size, &object_id)); + case MessageType::PlasmaReleaseRequest: { + RETURN_NOT_OK(ReadReleaseRequest(message_data, message_size, &object_id)); ReleaseObject(object_id, client); } break; - case fb::MessageType::PlasmaDeleteRequest: { + case MessageType::PlasmaDeleteRequest: { std::vector object_ids; std::vector error_codes; - RETURN_NOT_OK(ReadDeleteRequest(input, input_size, &object_ids)); + RETURN_NOT_OK(ReadDeleteRequest(message_data, message_size, &object_ids)); error_codes.reserve(object_ids.size()); for (auto& object_id : object_ids) { error_codes.push_back(DeleteObject(object_id)); } - HANDLE_SIGPIPE(SendDeleteReply(client->fd, object_ids, error_codes), client->fd); + RETURN_NOT_OK(SendDeleteReply(client, object_ids, error_codes)); } 
break; - case fb::MessageType::PlasmaContainsRequest: { - RETURN_NOT_OK(ReadContainsRequest(input, input_size, &object_id)); - if (ContainsObject(object_id) == ObjectStatus::OBJECT_FOUND) { - HANDLE_SIGPIPE(SendContainsReply(client->fd, object_id, 1), client->fd); - } else { - HANDLE_SIGPIPE(SendContainsReply(client->fd, object_id, 0), client->fd); - } + case MessageType::PlasmaContainsRequest: { + RETURN_NOT_OK(ReadContainsRequest(message_data, message_size, &object_id)); + auto has_object = (ContainsObject(object_id) == ObjectStatus::OBJECT_FOUND); + RETURN_NOT_OK(SendContainsReply(client, object_id, has_object)); } break; - case fb::MessageType::PlasmaListRequest: { - RETURN_NOT_OK(ReadListRequest(input, input_size)); - HANDLE_SIGPIPE(SendListReply(client->fd, store_info_.objects), client->fd); + case MessageType::PlasmaListRequest: { + RETURN_NOT_OK(ReadListRequest(message_data, message_size)); + RETURN_NOT_OK(SendListReply(client, store_info_.objects)); } break; - case fb::MessageType::PlasmaSealRequest: { + case MessageType::PlasmaSealRequest: { unsigned char digest[kDigestSize]; - RETURN_NOT_OK(ReadSealRequest(input, input_size, &object_id, &digest[0])); + RETURN_NOT_OK(ReadSealRequest(message_data, message_size, &object_id, &digest[0])); SealObject(object_id, &digest[0]); } break; - case fb::MessageType::PlasmaEvictRequest: { + case MessageType::PlasmaEvictRequest: { // This code path should only be used for testing. 
int64_t num_bytes; - RETURN_NOT_OK(ReadEvictRequest(input, input_size, &num_bytes)); + RETURN_NOT_OK(ReadEvictRequest(message_data, message_size, &num_bytes)); std::vector objects_to_evict; int64_t num_bytes_evicted = eviction_policy_.ChooseObjectsToEvict(num_bytes, &objects_to_evict); EvictObjects(objects_to_evict); - HANDLE_SIGPIPE(SendEvictReply(client->fd, num_bytes_evicted), client->fd); + RETURN_NOT_OK(SendEvictReply(client, num_bytes_evicted)); } break; - case fb::MessageType::PlasmaSubscribeRequest: + case MessageType::PlasmaSubscribeRequest: SubscribeToUpdates(client); break; - case fb::MessageType::PlasmaConnectRequest: { - HANDLE_SIGPIPE(SendConnectReply(client->fd, PlasmaAllocator::GetFootprintLimit()), - client->fd); + case MessageType::PlasmaConnectRequest: { + RETURN_NOT_OK(SendConnectReply(client, PlasmaAllocator::GetFootprintLimit())); } break; - case fb::MessageType::PlasmaDisconnectClient: - ARROW_LOG(DEBUG) << "Disconnecting client on fd " << client->fd; - DisconnectClient(client->fd); - break; + case MessageType::PlasmaDisconnectClient: + ARROW_LOG(DEBUG) << "Disconnecting client on fd " << client->GetNativeHandle(); + ProcessDisconnectClient(client); + return Status::OK(); // Stop listening for more messages. default: // This code should be unreachable. ARROW_CHECK(0); } + // Listen for more messages. + client->ProcessMessages(); return Status::OK(); } @@ -1000,11 +912,16 @@ class PlasmaStoreRunner { public: PlasmaStoreRunner() {} - void Start(char* socket_name, std::string directory, bool hugepages_enabled, - std::shared_ptr external_store) { + void Start(const std::string& stream_name, std::string directory, + bool hugepages_enabled, std::shared_ptr external_store) { + signal_set_.async_wait([this](std::error_code ec, int signal) { + if (signal == SIGTERM) { + ARROW_LOG(INFO) << "SIGTERM Signal received, closing Plasma Server..."; + Stop(); + } + }); // Create the event loop. 
- loop_.reset(new EventLoop); - store_.reset(new PlasmaStore(loop_.get(), directory, hugepages_enabled, socket_name, + store_.reset(new PlasmaStore(io_context_, directory, hugepages_enabled, stream_name, external_store)); plasma_config = store_->GetPlasmaStoreInfo(); @@ -1012,7 +929,7 @@ class PlasmaStoreRunner { // large amount of space up front. According to the documentation, // dlmalloc might need up to 128*sizeof(size_t) bytes for internal // bookkeeping. - void* pointer = plasma::PlasmaAllocator::Memalign( + void* pointer = PlasmaAllocator::Memalign( kBlockSize, PlasmaAllocator::GetFootprintLimit() - 256 * sizeof(size_t)); ARROW_CHECK(pointer != nullptr); // This will unmap the file, but the next one created will be as large @@ -1020,57 +937,33 @@ class PlasmaStoreRunner { plasma::PlasmaAllocator::Free( pointer, PlasmaAllocator::GetFootprintLimit() - 256 * sizeof(size_t)); - int socket = BindIpcSock(socket_name, true); - // TODO(pcm): Check return value. - ARROW_CHECK(socket >= 0); - - loop_->AddFileEvent(socket, kEventLoopRead, [this, socket](int events) { - this->store_->ConnectClient(socket); - }); - loop_->Start(); + io_context_.run(); } - void Stop() { loop_->Stop(); } + void Stop() { io_context_.stop(); } void Shutdown() { - loop_->Shutdown(); - loop_ = nullptr; + io_context_.stop(); store_ = nullptr; } private: - std::unique_ptr loop_; + asio::io_context io_context_; + asio::signal_set signal_set_{io_context_, SIGTERM}; + // Ignore SIGPIPE signals. If we don't do this, then when we attempt to write + // to a client that has already died, the store could die. 
+ asio::signal_set signal_ign_{io_context_, SIGPIPE}; std::unique_ptr store_; }; static std::unique_ptr g_runner = nullptr; -void HandleSignal(int signal) { - if (signal == SIGTERM) { - ARROW_LOG(INFO) << "SIGTERM Signal received, closing Plasma Server..."; - if (g_runner != nullptr) { - g_runner->Stop(); - } - } -} - -void StartServer(char* socket_name, std::string plasma_directory, bool hugepages_enabled, - std::shared_ptr external_store) { - // Ignore SIGPIPE signals. If we don't do this, then when we attempt to write - // to a client that has already died, the store could die. - signal(SIGPIPE, SIG_IGN); - - g_runner.reset(new PlasmaStoreRunner()); - signal(SIGTERM, HandleSignal); - g_runner->Start(socket_name, plasma_directory, hugepages_enabled, external_store); -} - } // namespace plasma int main(int argc, char* argv[]) { ArrowLog::StartArrowLog(argv[0], ArrowLogLevel::ARROW_INFO); ArrowLog::InstallFailureSignalHandler(); - char* socket_name = nullptr; + char* stream_name = nullptr; // Directory where plasma memory mapped files are stored. std::string plasma_directory; std::string external_store_endpoint; @@ -1089,7 +982,7 @@ int main(int argc, char* argv[]) { hugepages_enabled = true; break; case 's': - socket_name = optarg; + stream_name = optarg; break; case 'm': { char extra; @@ -1107,7 +1000,7 @@ int main(int argc, char* argv[]) { } } // Sanity check command line options. 
- if (!socket_name) { + if (!stream_name) { ARROW_LOG(FATAL) << "please specify socket for incoming connections with -s switch"; } if (system_memory == -1) { @@ -1168,8 +1061,10 @@ int main(int argc, char* argv[]) { ARROW_LOG(DEBUG) << "connecting to external store..."; ARROW_CHECK_OK(external_store->Connect(external_store_endpoint)); } - ARROW_LOG(DEBUG) << "starting server listening on " << socket_name; - plasma::StartServer(socket_name, plasma_directory, hugepages_enabled, external_store); + ARROW_LOG(DEBUG) << "starting server listening on " << stream_name; + plasma::g_runner.reset(new plasma::PlasmaStoreRunner()); + plasma::g_runner->Start(stream_name, plasma_directory, hugepages_enabled, + external_store); plasma::g_runner->Shutdown(); plasma::g_runner = nullptr; diff --git a/cpp/src/plasma/store.h b/cpp/src/plasma/store.h index 53464abde8f..2c2d84a2a1c 100644 --- a/cpp/src/plasma/store.h +++ b/cpp/src/plasma/store.h @@ -18,7 +18,6 @@ #ifndef PLASMA_STORE_H #define PLASMA_STORE_H -#include #include #include #include @@ -26,11 +25,10 @@ #include #include "plasma/common.h" -#include "plasma/events.h" #include "plasma/eviction_policy.h" #include "plasma/external_store.h" +#include "plasma/io/connection.h" #include "plasma/plasma.h" -#include "plasma/protocol.h" namespace arrow { class Status; @@ -39,46 +37,18 @@ class Status; namespace plasma { namespace flatbuf { -struct ObjectInfoT; enum class PlasmaError; } // namespace flatbuf -using flatbuf::ObjectInfoT; using flatbuf::PlasmaError; +using io::ClientConnection; struct GetRequest; -struct NotificationQueue { - /// The object notifications for clients. We notify the client about the - /// objects in the order that the objects were sealed or deleted. - std::deque> object_notifications; -}; - -/// Contains all information that is associated with a Plasma store client. -struct Client { - explicit Client(int fd); - - /// The file descriptor used to communicate with the client. 
- int fd; - - /// Object ids that are used by this client. - std::unordered_set object_ids; - - /// File descriptors that are used by this client. - std::unordered_set used_fds; - - /// The file descriptor used to push notifications to client. This is only valid - /// if client subscribes to plasma store. -1 indicates invalid. - int notification_fd; -}; - class PlasmaStore { public: - using NotificationMap = std::unordered_map; - - // TODO: PascalCase PlasmaStore methods. - PlasmaStore(EventLoop* loop, std::string directory, bool hugepages_enabled, - const std::string& socket_name, + PlasmaStore(asio::io_context& main_context, std::string directory, + bool hugepages_enabled, const std::string& stream_name, std::shared_ptr external_store); ~PlasmaStore(); @@ -89,17 +59,17 @@ class PlasmaStore { /// Create a new object. The client must do a call to release_object to tell /// the store when it is done with the object. /// - /// @param object_id Object ID of the object to be created. - /// @param data_size Size in bytes of the object to be created. - /// @param metadata_size Size in bytes of the object metadata. - /// @param device_num The number of the device where the object is being + /// \param object_id Object ID of the object to be created. + /// \param data_size Size in bytes of the object to be created. + /// \param metadata_size Size in bytes of the object metadata. + /// \param device_num The number of the device where the object is being /// created. /// device_num = 0 corresponds to the host, /// device_num = 1 corresponds to GPU0, /// device_num = 2 corresponds to GPU1, etc. - /// @param client The client that created the object. - /// @param result The object that has been created. - /// @return One of the following error codes: + /// \param client The client that created the object. + /// \param result The object that has been created. + /// \return One of the following error codes: /// - PlasmaError::OK, if the object was created successfully. 
/// - PlasmaError::ObjectExists, if an object with this ID is already /// present in the store. In this case, the client should not call @@ -108,22 +78,24 @@ class PlasmaStore { /// cannot create the object. In this case, the client should not call /// plasma_release. PlasmaError CreateObject(const ObjectID& object_id, int64_t data_size, - int64_t metadata_size, int device_num, Client* client, + int64_t metadata_size, int device_num, + const std::shared_ptr& client, PlasmaObject* result); /// Abort a created but unsealed object. If the client is not the /// creator, then the abort will fail. /// - /// @param object_id Object ID of the object to be aborted. - /// @param client The client who created the object. If this does not + /// \param object_id Object ID of the object to be aborted. + /// \param client The client who created the object. If this does not /// match the creator of the object, then the abort will fail. - /// @return 1 if the abort succeeds, else 0. - int AbortObject(const ObjectID& object_id, Client* client); + /// \return 1 if the abort succeeds, else 0. + int AbortObject(const ObjectID& object_id, + const std::shared_ptr& client); /// Delete an specific object by object_id that have been created in the hash table. /// - /// @param object_id Object ID of the object to be deleted. - /// @return One of the following error codes: + /// \param object_id Object ID of the object to be deleted. + /// \return One of the following error codes: /// - PlasmaError::OK, if the object was delete successfully. /// - PlasmaError::ObjectNonexistent, if ths object isn't existed. /// - PlasmaError::ObjectInUse, if the object is in use. @@ -131,7 +103,7 @@ class PlasmaStore { /// Evict objects returned by the eviction policy. /// - /// @param object_ids Object IDs of the objects to be evicted. + /// \param object_ids Object IDs of the objects to be evicted. void EvictObjects(const std::vector& object_ids); /// Process a get request from a client. 
This method assumes that we will @@ -142,105 +114,102 @@ class PlasmaStore { /// For each object, the client must do a call to release_object to tell the /// store when it is done with the object. /// - /// @param client The client making this request. - /// @param object_ids Object IDs of the objects to be gotten. - /// @param timeout_ms The timeout for the get request in milliseconds. - void ProcessGetRequest(Client* client, const std::vector& object_ids, - int64_t timeout_ms); + /// \param client The client making this request. + /// \param object_ids Object IDs of the objects to be gotten. + /// \param timeout_ms The timeout for the get request in milliseconds. + Status ProcessGetRequest(const std::shared_ptr& client, + const std::vector& object_ids, int64_t timeout_ms); /// Seal an object. The object is now immutable and can be accessed with get. /// - /// @param object_id Object ID of the object to be sealed. - /// @param digest The digest of the object. This is used to tell if two + /// \param object_id Object ID of the object to be sealed. + /// \param digest The digest of the object. This is used to tell if two /// objects with the same object ID are the same. void SealObject(const ObjectID& object_id, unsigned char digest[]); /// Check if the plasma store contains an object: /// - /// @param object_id Object ID that will be checked. - /// @return OBJECT_FOUND if the object is in the store, OBJECT_NOT_FOUND if + /// \param object_id Object ID that will be checked. + /// \return OBJECT_FOUND if the object is in the store, OBJECT_NOT_FOUND if /// not ObjectStatus ContainsObject(const ObjectID& object_id); /// Record the fact that a particular client is no longer using an object. /// - /// @param object_id The object ID of the object that is being released. - /// @param client The client making this request. - void ReleaseObject(const ObjectID& object_id, Client* client); + /// \param object_id The object ID of the object that is being released. 
+ /// \param client The client making this request. + void ReleaseObject(const ObjectID& object_id, + const std::shared_ptr& client); /// Subscribe a file descriptor to updates about new sealed objects. /// - /// @param client The client making this request. - void SubscribeToUpdates(Client* client); - - /// Connect a new client to the PlasmaStore. - /// - /// @param listener_sock The socket that is listening to incoming connections. - void ConnectClient(int listener_sock); - - /// Disconnect a client from the PlasmaStore. - /// - /// @param client_fd The client file descriptor that is disconnected. - void DisconnectClient(int client_fd); - - NotificationMap::iterator SendNotifications(NotificationMap::iterator it); - - arrow::Status ProcessMessage(Client* client); + /// \param client The client making this request. + void SubscribeToUpdates(const std::shared_ptr& client); private: - void PushNotification(ObjectInfoT* object_notification); + // Inform all subscribers that a new object has been sealed. + void PushObjectReadyNotification(const ObjectID& object_id, + const ObjectTableEntry& entry); - void PushNotification(ObjectInfoT* object_notification, int client_fd); + // Inform all subscribers that an object has evicted. + void PushObjectDeletionNotification(const ObjectID& object_id); void AddToClientObjectIds(const ObjectID& object_id, ObjectTableEntry* entry, - Client* client); + const std::shared_ptr& client); /// Remove a GetRequest and clean up the relevant data structures. /// - /// @param get_request The GetRequest to remove. + /// \param get_request The GetRequest to remove. void RemoveGetRequest(GetRequest* get_request); /// Remove all of the GetRequests for a given client. /// - /// @param client The client whose GetRequests should be removed. - void RemoveGetRequestsForClient(Client* client); + /// \param client The client whose GetRequests should be removed. 
+ void RemoveGetRequestsForClient(const std::shared_ptr& client); void ReturnFromGet(GetRequest* get_req); void UpdateObjectGetRequests(const ObjectID& object_id); - int RemoveFromClientObjectIds(const ObjectID& object_id, ObjectTableEntry* entry, - Client* client); - void EraseFromObjectTable(const ObjectID& object_id); uint8_t* AllocateMemory(size_t size, int* fd, int64_t* map_size, ptrdiff_t* offset); + + void IncreaseObjectRefCount(const ObjectID& object_id, ObjectTableEntry* entry); + + void DecreaseObjectRefCount(const ObjectID& object_id, ObjectTableEntry* entry); + #ifdef PLASMA_CUDA Status AllocateCudaMemory(int device_num, int64_t size, uint8_t** out_pointer, std::shared_ptr* out_ipc_handle); #endif - /// Event loop of the plasma store. - EventLoop* loop_; + /// Accept a client connection. + void DoAccept(); + /// Handle an accepted client connection. + void HandleAccept(const asio::error_code& error); + + Status ProcessClientMessage(const std::shared_ptr& client, + int64_t message_type, int64_t message_size, + const uint8_t* message_data); + + /// Disconnect a client from the PlasmaStore. + /// + /// \param client The client that is disconnected. + void ProcessDisconnectClient(const std::shared_ptr& client); + /// The plasma store information, including the object tables, that is exposed /// to the eviction policy. PlasmaStoreInfo store_info_; /// The state that is managed by the eviction policy. EvictionPolicy eviction_policy_; - /// Input buffer. This is allocated only once to avoid mallocs for every - /// call to process_message. - std::vector input_buffer_; /// A hash table mapping object IDs to a vector of the get requests that are /// waiting for the object to arrive. std::unordered_map> object_get_requests_; - /// The pending notifications that have not been sent to subscribers because - /// the socket send buffers were full. This is a hash table from client file - /// descriptor to an array of object_ids to send to that client. 
- /// TODO(pcm): Consider putting this into the Client data structure and - /// reorganize the code slightly. - NotificationMap pending_notifications_; - std::unordered_map> connected_clients_; + std::unordered_set> notification_clients_; + + std::unordered_set> connected_clients_; std::unordered_set deletion_cache_; @@ -250,6 +219,13 @@ class PlasmaStore { #ifdef PLASMA_CUDA arrow::cuda::CudaDeviceManager* manager_; #endif + asio::io_context& io_context_; + /// The name of the stream this store server listens on. + std::string stream_name_; + /// An acceptor for new clients. + io::PlasmaAcceptor acceptor_; + /// The stream to listen on for new clients. + io::PlasmaStream stream_; }; } // namespace plasma diff --git a/cpp/src/plasma/test/client_tests.cc b/cpp/src/plasma/test/client_tests.cc index 4dd0c066dd4..d694758be5b 100644 --- a/cpp/src/plasma/test/client_tests.cc +++ b/cpp/src/plasma/test/client_tests.cc @@ -15,13 +15,6 @@ // specific language governing permissions and limitations // under the License. -#include -#include -#include -#include -#include -#include - #include #include #include @@ -125,15 +118,13 @@ TEST_F(TestPlasmaStore, NewSubscriberTest) { ARROW_CHECK_OK(local_client.Seal(object_id)); // Test that new subscriber client2 can receive notifications about existing objects. 
- int fd = -1; - ARROW_CHECK_OK(local_client2.Subscribe(&fd)); - ASSERT_GT(fd, 0); + ARROW_CHECK_OK(local_client2.Subscribe()); ObjectID object_id2 = random_object_id(); int64_t data_size2 = 0; int64_t metadata_size2 = 0; ARROW_CHECK_OK( - local_client2.GetNotification(fd, &object_id2, &data_size2, &metadata_size2)); + local_client2.GetNotification(&object_id2, &data_size2, &metadata_size2)); ASSERT_EQ(object_id, object_id2); ASSERT_EQ(data_size, data_size2); ASSERT_EQ(metadata_size, metadata_size2); @@ -143,7 +134,7 @@ TEST_F(TestPlasmaStore, NewSubscriberTest) { ARROW_CHECK_OK(local_client.Delete(object_id)); ARROW_CHECK_OK( - local_client2.GetNotification(fd, &object_id2, &data_size2, &metadata_size2)); + local_client2.GetNotification(&object_id2, &data_size2, &metadata_size2)); ASSERT_EQ(object_id, object_id2); ASSERT_EQ(-1, data_size2); ASSERT_EQ(-1, metadata_size2); diff --git a/cpp/src/plasma/test/serialization_tests.cc b/cpp/src/plasma/test/serialization_tests.cc index 6c554a655bc..3cfd4c5c4c8 100644 --- a/cpp/src/plasma/test/serialization_tests.cc +++ b/cpp/src/plasma/test/serialization_tests.cc @@ -15,53 +15,56 @@ // specific language governing permissions and limitations // under the License. -#include -#include - #include +#include #include "arrow/testing/gtest_util.h" - #include "plasma/common.h" -#include "plasma/io.h" -#include "plasma/plasma.h" +#include "plasma/io/connection.h" #include "plasma/protocol.h" #include "plasma/test-util.h" -namespace fb = plasma::flatbuf; - namespace plasma { -/** - * Create a temporary file. Needs to be closed by the caller. - * - * @return File descriptor of the file. 
- */ -int create_temp_file(void) { - static char temp[] = "/tmp/tempfileXXXXXX"; - char file_name[32]; - strncpy(file_name, temp, 32); - return mkstemp(file_name); +using flatbuf::MessageType; +using io::ClientConnection; +using io::ServerConnection; + +class TestPlasmaSerialization : public ::testing::Test { + public: + void SetUp() override { + using asio::local::stream_protocol; + stream_protocol::socket parentSocket(io_context_); + stream_protocol::socket childSocket(io_context_); + // create socket pair + asio::local::connect_pair(childSocket, parentSocket); + client_ = ServerConnection::Create(std::move(childSocket)); + io::MessageHandler monk_handler = [](std::shared_ptr client, + int64_t type, int64_t length, + const uint8_t* msg) {}; + server_ = + ClientConnection::Create(std::move(parentSocket), monk_handler, "PlasmaClient"); + } + + void TearDown() override { + client_->Close(); + server_->Close(); + } + + protected: + asio::io_context io_context_; + std::shared_ptr client_; + std::shared_ptr server_; +}; + +Status PlasmaReceive(const std::shared_ptr& client, + MessageType message_type, std::vector* buffer) { + return client->ReadMessage(static_cast(message_type), buffer); } -/** - * Seek to the beginning of a file and read a message from it. - * - * @param fd File descriptor of the file. - * @param message_type Message type that we expect in the file. - * - * @return Pointer to the content of the message. Needs to be freed by the - * caller. - */ -std::vector read_message_from_file(int fd, MessageType message_type) { - /* Go to the beginning of the file. 
*/ - lseek(fd, 0, SEEK_SET); - MessageType type; - std::vector data; - Status s = ReadMessage(fd, &type, &data); - DCHECK_OK(s); - DCHECK_EQ(type, message_type); - return data; +Status PlasmaReceive(const std::shared_ptr& client, + MessageType message_type, std::vector* buffer) { + return client->ReadMessage(static_cast(message_type), buffer); } PlasmaObject random_plasma_object(void) { @@ -77,15 +80,15 @@ PlasmaObject random_plasma_object(void) { return object; } -TEST(PlasmaSerialization, CreateRequest) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, CreateRequest) { ObjectID object_id1 = random_object_id(); int64_t data_size1 = 42; int64_t metadata_size1 = 11; int device_num1 = 0; - ASSERT_OK(SendCreateRequest(fd, object_id1, data_size1, metadata_size1, device_num1)); - std::vector data = - read_message_from_file(fd, MessageType::PlasmaCreateRequest); + ASSERT_OK( + SendCreateRequest(client_, object_id1, data_size1, metadata_size1, device_num1)); + std::vector data; + ASSERT_OK(PlasmaReceive(server_, MessageType::PlasmaCreateRequest, &data)); ObjectID object_id2; int64_t data_size2; int64_t metadata_size2; @@ -96,16 +99,15 @@ TEST(PlasmaSerialization, CreateRequest) { ASSERT_EQ(metadata_size1, metadata_size2); ASSERT_EQ(object_id1, object_id2); ASSERT_EQ(device_num1, device_num2); - close(fd); } -TEST(PlasmaSerialization, CreateReply) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, CreateReply) { ObjectID object_id1 = random_object_id(); PlasmaObject object1 = random_plasma_object(); int64_t mmap_size1 = 1000000; - ASSERT_OK(SendCreateReply(fd, object_id1, &object1, PlasmaError::OK, mmap_size1)); - std::vector data = read_message_from_file(fd, MessageType::PlasmaCreateReply); + ASSERT_OK(SendCreateReply(server_, object_id1, &object1, PlasmaError::OK, mmap_size1)); + std::vector data; + ASSERT_OK(PlasmaReceive(client_, MessageType::PlasmaCreateReply, &data)); ObjectID object_id2; PlasmaObject object2 = {}; int store_fd; @@ 
-116,44 +118,41 @@ TEST(PlasmaSerialization, CreateReply) { ASSERT_EQ(object1.store_fd, store_fd); ASSERT_EQ(mmap_size1, mmap_size2); ASSERT_EQ(memcmp(&object1, &object2, sizeof(object1)), 0); - close(fd); } -TEST(PlasmaSerialization, SealRequest) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, SealRequest) { ObjectID object_id1 = random_object_id(); unsigned char digest1[kDigestSize]; memset(&digest1[0], 7, kDigestSize); - ASSERT_OK(SendSealRequest(fd, object_id1, &digest1[0])); - std::vector data = read_message_from_file(fd, MessageType::PlasmaSealRequest); + ASSERT_OK(SendSealRequest(client_, object_id1, &digest1[0])); + std::vector data; + ASSERT_OK(PlasmaReceive(server_, MessageType::PlasmaSealRequest, &data)); ObjectID object_id2; unsigned char digest2[kDigestSize]; ASSERT_OK(ReadSealRequest(data.data(), data.size(), &object_id2, &digest2[0])); ASSERT_EQ(object_id1, object_id2); ASSERT_EQ(memcmp(&digest1[0], &digest2[0], kDigestSize), 0); - close(fd); } -TEST(PlasmaSerialization, SealReply) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, SealReply) { ObjectID object_id1 = random_object_id(); - ASSERT_OK(SendSealReply(fd, object_id1, PlasmaError::ObjectExists)); - std::vector data = read_message_from_file(fd, MessageType::PlasmaSealReply); + ASSERT_OK(SendSealReply(server_, object_id1, PlasmaError::ObjectExists)); + std::vector data; + ASSERT_OK(PlasmaReceive(client_, MessageType::PlasmaSealReply, &data)); ObjectID object_id2; Status s = ReadSealReply(data.data(), data.size(), &object_id2); ASSERT_EQ(object_id1, object_id2); ASSERT_TRUE(s.IsPlasmaObjectExists()); - close(fd); } -TEST(PlasmaSerialization, GetRequest) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, GetRequest) { ObjectID object_ids[2]; object_ids[0] = random_object_id(); object_ids[1] = random_object_id(); int64_t timeout_ms = 1234; - ASSERT_OK(SendGetRequest(fd, object_ids, 2, timeout_ms)); - std::vector data = read_message_from_file(fd, 
MessageType::PlasmaGetRequest); + ASSERT_OK(SendGetRequest(client_, object_ids, 2, timeout_ms)); + std::vector data; + ASSERT_OK(PlasmaReceive(server_, MessageType::PlasmaGetRequest, &data)); std::vector object_ids_return; int64_t timeout_ms_return; ASSERT_OK( @@ -161,11 +160,9 @@ TEST(PlasmaSerialization, GetRequest) { ASSERT_EQ(object_ids[0], object_ids_return[0]); ASSERT_EQ(object_ids[1], object_ids_return[1]); ASSERT_EQ(timeout_ms, timeout_ms_return); - close(fd); } -TEST(PlasmaSerialization, GetReply) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, GetReply) { ObjectID object_ids[2]; object_ids[0] = random_object_id(); object_ids[1] = random_object_id(); @@ -174,9 +171,10 @@ TEST(PlasmaSerialization, GetReply) { plasma_objects[object_ids[1]] = random_plasma_object(); std::vector store_fds = {1, 2, 3}; std::vector mmap_sizes = {100, 200, 300}; - ASSERT_OK(SendGetReply(fd, object_ids, plasma_objects, 2, store_fds, mmap_sizes)); + ASSERT_OK(SendGetReply(server_, object_ids, plasma_objects, 2, store_fds, mmap_sizes)); - std::vector data = read_message_from_file(fd, MessageType::PlasmaGetReply); + std::vector data; + ASSERT_OK(PlasmaReceive(client_, MessageType::PlasmaGetReply, &data)); ObjectID object_ids_return[2]; PlasmaObject plasma_objects_return[2]; std::vector store_fds_return; @@ -196,53 +194,47 @@ TEST(PlasmaSerialization, GetReply) { 0); ASSERT_TRUE(store_fds == store_fds_return); ASSERT_TRUE(mmap_sizes == mmap_sizes_return); - close(fd); } -TEST(PlasmaSerialization, ReleaseRequest) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, ReleaseRequest) { ObjectID object_id1 = random_object_id(); - ASSERT_OK(SendReleaseRequest(fd, object_id1)); - std::vector data = - read_message_from_file(fd, MessageType::PlasmaReleaseRequest); + ASSERT_OK(SendReleaseRequest(client_, object_id1)); + std::vector data; + ASSERT_OK(PlasmaReceive(server_, MessageType::PlasmaReleaseRequest, &data)); ObjectID object_id2; 
ASSERT_OK(ReadReleaseRequest(data.data(), data.size(), &object_id2)); ASSERT_EQ(object_id1, object_id2); - close(fd); } -TEST(PlasmaSerialization, ReleaseReply) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, ReleaseReply) { ObjectID object_id1 = random_object_id(); - ASSERT_OK(SendReleaseReply(fd, object_id1, PlasmaError::ObjectExists)); - std::vector data = read_message_from_file(fd, MessageType::PlasmaReleaseReply); + ASSERT_OK(SendReleaseReply(server_, object_id1, PlasmaError::ObjectExists)); + std::vector data; + ASSERT_OK(PlasmaReceive(client_, MessageType::PlasmaReleaseReply, &data)); ObjectID object_id2; Status s = ReadReleaseReply(data.data(), data.size(), &object_id2); ASSERT_EQ(object_id1, object_id2); ASSERT_TRUE(s.IsPlasmaObjectExists()); - close(fd); } -TEST(PlasmaSerialization, DeleteRequest) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, DeleteRequest) { ObjectID object_id1 = random_object_id(); - ASSERT_OK(SendDeleteRequest(fd, std::vector{object_id1})); - std::vector data = - read_message_from_file(fd, MessageType::PlasmaDeleteRequest); + ASSERT_OK(SendDeleteRequest(client_, std::vector{object_id1})); + std::vector data; + ASSERT_OK(PlasmaReceive(server_, MessageType::PlasmaDeleteRequest, &data)); std::vector object_vec; ASSERT_OK(ReadDeleteRequest(data.data(), data.size(), &object_vec)); ASSERT_EQ(object_vec.size(), 1); ASSERT_EQ(object_id1, object_vec[0]); - close(fd); } -TEST(PlasmaSerialization, DeleteReply) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, DeleteReply) { ObjectID object_id1 = random_object_id(); PlasmaError error1 = PlasmaError::ObjectExists; - ASSERT_OK(SendDeleteReply(fd, std::vector{object_id1}, + ASSERT_OK(SendDeleteReply(server_, std::vector{object_id1}, std::vector{error1})); - std::vector data = read_message_from_file(fd, MessageType::PlasmaDeleteReply); + std::vector data; + ASSERT_OK(PlasmaReceive(client_, MessageType::PlasmaDeleteReply, &data)); std::vector 
object_vec; std::vector error_vec; Status s = ReadDeleteReply(data.data(), data.size(), &object_vec, &error_vec); @@ -251,39 +243,36 @@ TEST(PlasmaSerialization, DeleteReply) { ASSERT_EQ(error_vec.size(), 1); ASSERT_TRUE(error_vec[0] == PlasmaError::ObjectExists); ASSERT_TRUE(s.ok()); - close(fd); } -TEST(PlasmaSerialization, EvictRequest) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, EvictRequest) { int64_t num_bytes = 111; - ASSERT_OK(SendEvictRequest(fd, num_bytes)); - std::vector data = read_message_from_file(fd, MessageType::PlasmaEvictRequest); + ASSERT_OK(SendEvictRequest(client_, num_bytes)); + std::vector data; + ASSERT_OK(PlasmaReceive(server_, MessageType::PlasmaEvictRequest, &data)); int64_t num_bytes_received; ASSERT_OK(ReadEvictRequest(data.data(), data.size(), &num_bytes_received)); ASSERT_EQ(num_bytes, num_bytes_received); - close(fd); } -TEST(PlasmaSerialization, EvictReply) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, EvictReply) { int64_t num_bytes = 111; - ASSERT_OK(SendEvictReply(fd, num_bytes)); - std::vector data = read_message_from_file(fd, MessageType::PlasmaEvictReply); + ASSERT_OK(SendEvictReply(server_, num_bytes)); + std::vector data; + ASSERT_OK(PlasmaReceive(client_, MessageType::PlasmaEvictReply, &data)); int64_t num_bytes_received; ASSERT_OK(ReadEvictReply(data.data(), data.size(), num_bytes_received)); ASSERT_EQ(num_bytes, num_bytes_received); - close(fd); } -TEST(PlasmaSerialization, DataRequest) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, DataRequest) { ObjectID object_id1 = random_object_id(); const char* address1 = "address1"; int port1 = 12345; - ASSERT_OK(SendDataRequest(fd, object_id1, address1, port1)); + ASSERT_OK(SendDataRequest(client_, object_id1, address1, port1)); /* Reading message back. 
*/ - std::vector data = read_message_from_file(fd, MessageType::PlasmaDataRequest); + std::vector data; + ASSERT_OK(PlasmaReceive(server_, MessageType::PlasmaDataRequest, &data)); ObjectID object_id2; char* address2; int port2; @@ -292,17 +281,16 @@ TEST(PlasmaSerialization, DataRequest) { ASSERT_EQ(strcmp(address1, address2), 0); ASSERT_EQ(port1, port2); free(address2); - close(fd); } -TEST(PlasmaSerialization, DataReply) { - int fd = create_temp_file(); +TEST_F(TestPlasmaSerialization, DataReply) { ObjectID object_id1 = random_object_id(); int64_t object_size1 = 146; int64_t metadata_size1 = 198; - ASSERT_OK(SendDataReply(fd, object_id1, object_size1, metadata_size1)); + ASSERT_OK(SendDataReply(server_, object_id1, object_size1, metadata_size1)); /* Reading message back. */ - std::vector data = read_message_from_file(fd, MessageType::PlasmaDataReply); + std::vector data; + ASSERT_OK(PlasmaReceive(client_, MessageType::PlasmaDataReply, &data)); ObjectID object_id2; int64_t object_size2; int64_t metadata_size2; diff --git a/cpp/src/plasma/thirdparty/ae/ae.c b/cpp/src/plasma/thirdparty/ae/ae.c deleted file mode 100644 index dfb72244409..00000000000 --- a/cpp/src/plasma/thirdparty/ae/ae.c +++ /dev/null @@ -1,465 +0,0 @@ -/* A simple event-driven programming library. Originally I wrote this code - * for the Jim's event-loop (Jim is a Tcl interpreter) but later translated - * it in form of a library for easy reuse. - * - * Copyright (c) 2006-2010, Salvatore Sanfilippo - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "plasma/thirdparty/ae/ae.h" -#include "plasma/thirdparty/ae/zmalloc.h" -#include "plasma/thirdparty/ae/config.h" - -/* Include the best multiplexing layer supported by this system. - * The following should be ordered by performances, descending. 
*/ -#ifdef HAVE_EVPORT -#include "plasma/thirdparty/ae/ae_evport.c" -#else - #ifdef HAVE_EPOLL - #include "plasma/thirdparty/ae/ae_epoll.c" - #else - #ifdef HAVE_KQUEUE - #include "plasma/thirdparty/ae/ae_kqueue.c" - #else - #include "plasma/thirdparty/ae/ae_select.c" - #endif - #endif -#endif - -aeEventLoop *aeCreateEventLoop(int setsize) { - aeEventLoop *eventLoop; - int i; - - if ((eventLoop = zmalloc(sizeof(*eventLoop))) == NULL) goto err; - eventLoop->events = zmalloc(sizeof(aeFileEvent)*setsize); - eventLoop->fired = zmalloc(sizeof(aeFiredEvent)*setsize); - if (eventLoop->events == NULL || eventLoop->fired == NULL) goto err; - eventLoop->setsize = setsize; - eventLoop->lastTime = time(NULL); - eventLoop->timeEventHead = NULL; - eventLoop->timeEventNextId = 0; - eventLoop->stop = 0; - eventLoop->maxfd = -1; - eventLoop->beforesleep = NULL; - if (aeApiCreate(eventLoop) == -1) goto err; - /* Events with mask == AE_NONE are not set. So let's initialize the - * vector with it. */ - for (i = 0; i < setsize; i++) - eventLoop->events[i].mask = AE_NONE; - return eventLoop; - -err: - if (eventLoop) { - zfree(eventLoop->events); - zfree(eventLoop->fired); - zfree(eventLoop); - } - return NULL; -} - -/* Return the current set size. */ -int aeGetSetSize(aeEventLoop *eventLoop) { - return eventLoop->setsize; -} - -/* Resize the maximum set size of the event loop. - * If the requested set size is smaller than the current set size, but - * there is already a file descriptor in use that is >= the requested - * set size minus one, AE_ERR is returned and the operation is not - * performed at all. - * - * Otherwise AE_OK is returned and the operation is successful. 
*/ -int aeResizeSetSize(aeEventLoop *eventLoop, int setsize) { - int i; - - if (setsize == eventLoop->setsize) return AE_OK; - if (eventLoop->maxfd >= setsize) return AE_ERR; - if (aeApiResize(eventLoop,setsize) == -1) return AE_ERR; - - eventLoop->events = zrealloc(eventLoop->events,sizeof(aeFileEvent)*setsize); - eventLoop->fired = zrealloc(eventLoop->fired,sizeof(aeFiredEvent)*setsize); - eventLoop->setsize = setsize; - - /* Make sure that if we created new slots, they are initialized with - * an AE_NONE mask. */ - for (i = eventLoop->maxfd+1; i < setsize; i++) - eventLoop->events[i].mask = AE_NONE; - return AE_OK; -} - -void aeDeleteEventLoop(aeEventLoop *eventLoop) { - aeApiFree(eventLoop); - zfree(eventLoop->events); - zfree(eventLoop->fired); - zfree(eventLoop); -} - -void aeStop(aeEventLoop *eventLoop) { - eventLoop->stop = 1; -} - -int aeCreateFileEvent(aeEventLoop *eventLoop, int fd, int mask, - aeFileProc *proc, void *clientData) -{ - if (fd >= eventLoop->setsize) { - errno = ERANGE; - return AE_ERR; - } - aeFileEvent *fe = &eventLoop->events[fd]; - - if (aeApiAddEvent(eventLoop, fd, mask) == -1) - return AE_ERR; - fe->mask |= mask; - if (mask & AE_READABLE) fe->rfileProc = proc; - if (mask & AE_WRITABLE) fe->wfileProc = proc; - fe->clientData = clientData; - if (fd > eventLoop->maxfd) - eventLoop->maxfd = fd; - return AE_OK; -} - -void aeDeleteFileEvent(aeEventLoop *eventLoop, int fd, int mask) -{ - if (fd >= eventLoop->setsize) return; - aeFileEvent *fe = &eventLoop->events[fd]; - if (fe->mask == AE_NONE) return; - - aeApiDelEvent(eventLoop, fd, mask); - fe->mask = fe->mask & (~mask); - if (fd == eventLoop->maxfd && fe->mask == AE_NONE) { - /* Update the max fd */ - int j; - - for (j = eventLoop->maxfd-1; j >= 0; j--) - if (eventLoop->events[j].mask != AE_NONE) break; - eventLoop->maxfd = j; - } -} - -int aeGetFileEvents(aeEventLoop *eventLoop, int fd) { - if (fd >= eventLoop->setsize) return 0; - aeFileEvent *fe = &eventLoop->events[fd]; - - return 
fe->mask; -} - -static void aeGetTime(long *seconds, long *milliseconds) -{ - struct timeval tv; - - gettimeofday(&tv, NULL); - *seconds = tv.tv_sec; - *milliseconds = tv.tv_usec/1000; -} - -static void aeAddMillisecondsToNow(long long milliseconds, long *sec, long *ms) { - long cur_sec, cur_ms, when_sec, when_ms; - - aeGetTime(&cur_sec, &cur_ms); - when_sec = cur_sec + milliseconds/1000; - when_ms = cur_ms + milliseconds%1000; - if (when_ms >= 1000) { - when_sec ++; - when_ms -= 1000; - } - *sec = when_sec; - *ms = when_ms; -} - -long long aeCreateTimeEvent(aeEventLoop *eventLoop, long long milliseconds, - aeTimeProc *proc, void *clientData, - aeEventFinalizerProc *finalizerProc) -{ - long long id = eventLoop->timeEventNextId++; - aeTimeEvent *te; - - te = zmalloc(sizeof(*te)); - if (te == NULL) return AE_ERR; - te->id = id; - aeAddMillisecondsToNow(milliseconds,&te->when_sec,&te->when_ms); - te->timeProc = proc; - te->finalizerProc = finalizerProc; - te->clientData = clientData; - te->next = eventLoop->timeEventHead; - eventLoop->timeEventHead = te; - return id; -} - -int aeDeleteTimeEvent(aeEventLoop *eventLoop, long long id) -{ - aeTimeEvent *te = eventLoop->timeEventHead; - while(te) { - if (te->id == id) { - te->id = AE_DELETED_EVENT_ID; - return AE_OK; - } - te = te->next; - } - return AE_ERR; /* NO event with the specified ID found */ -} - -/* Search the first timer to fire. - * This operation is useful to know how many time the select can be - * put in sleep without to delay any event. - * If there are no timers NULL is returned. - * - * Note that's O(N) since time events are unsorted. - * Possible optimizations (not needed by Redis so far, but...): - * 1) Insert the event in order, so that the nearest is just the head. - * Much better but still insertion or deletion of timers is O(N). - * 2) Use a skiplist to have this operation as O(1) and insertion as O(log(N)). 
- */ -static aeTimeEvent *aeSearchNearestTimer(aeEventLoop *eventLoop) -{ - aeTimeEvent *te = eventLoop->timeEventHead; - aeTimeEvent *nearest = NULL; - - while(te) { - if (!nearest || te->when_sec < nearest->when_sec || - (te->when_sec == nearest->when_sec && - te->when_ms < nearest->when_ms)) - nearest = te; - te = te->next; - } - return nearest; -} - -/* Process time events */ -static int processTimeEvents(aeEventLoop *eventLoop) { - int processed = 0; - aeTimeEvent *te, *prev; - long long maxId; - time_t now = time(NULL); - - /* If the system clock is moved to the future, and then set back to the - * right value, time events may be delayed in a random way. Often this - * means that scheduled operations will not be performed soon enough. - * - * Here we try to detect system clock skews, and force all the time - * events to be processed ASAP when this happens: the idea is that - * processing events earlier is less dangerous than delaying them - * indefinitely, and practice suggests it is. */ - if (now < eventLoop->lastTime) { - te = eventLoop->timeEventHead; - while(te) { - te->when_sec = 0; - te = te->next; - } - } - eventLoop->lastTime = now; - - prev = NULL; - te = eventLoop->timeEventHead; - maxId = eventLoop->timeEventNextId-1; - while(te) { - long now_sec, now_ms; - long long id; - - /* Remove events scheduled for deletion. */ - if (te->id == AE_DELETED_EVENT_ID) { - aeTimeEvent *next = te->next; - if (prev == NULL) - eventLoop->timeEventHead = te->next; - else - prev->next = te->next; - if (te->finalizerProc) - te->finalizerProc(eventLoop, te->clientData); - zfree(te); - te = next; - continue; - } - - /* Make sure we don't process time events created by time events in - * this iteration. Note that this check is currently useless: we always - * add new timers on the head, however if we change the implementation - * detail, this check may be useful again: we keep it here for future - * defense. 
*/ - if (te->id > maxId) { - te = te->next; - continue; - } - aeGetTime(&now_sec, &now_ms); - if (now_sec > te->when_sec || - (now_sec == te->when_sec && now_ms >= te->when_ms)) - { - int retval; - - id = te->id; - retval = te->timeProc(eventLoop, id, te->clientData); - processed++; - if (retval != AE_NOMORE) { - aeAddMillisecondsToNow(retval,&te->when_sec,&te->when_ms); - } else { - te->id = AE_DELETED_EVENT_ID; - } - } - prev = te; - te = te->next; - } - return processed; -} - -/* Process every pending time event, then every pending file event - * (that may be registered by time event callbacks just processed). - * Without special flags the function sleeps until some file event - * fires, or when the next time event occurs (if any). - * - * If flags is 0, the function does nothing and returns. - * if flags has AE_ALL_EVENTS set, all the kind of events are processed. - * if flags has AE_FILE_EVENTS set, file events are processed. - * if flags has AE_TIME_EVENTS set, time events are processed. - * if flags has AE_DONT_WAIT set the function returns ASAP until all - * the events that's possible to process without to wait are processed. - * - * The function returns the number of events processed. */ -int aeProcessEvents(aeEventLoop *eventLoop, int flags) -{ - int processed = 0, numevents; - - /* Nothing to do? return ASAP */ - if (!(flags & AE_TIME_EVENTS) && !(flags & AE_FILE_EVENTS)) return 0; - - /* Note that we want call select() even if there are no - * file events to process as long as we want to process time - * events, in order to sleep until the next time event is ready - * to fire. 
*/ - if (eventLoop->maxfd != -1 || - ((flags & AE_TIME_EVENTS) && !(flags & AE_DONT_WAIT))) { - int j; - aeTimeEvent *shortest = NULL; - struct timeval tv, *tvp; - - if (flags & AE_TIME_EVENTS && !(flags & AE_DONT_WAIT)) - shortest = aeSearchNearestTimer(eventLoop); - if (shortest) { - long now_sec, now_ms; - - aeGetTime(&now_sec, &now_ms); - tvp = &tv; - - /* How many milliseconds we need to wait for the next - * time event to fire? */ - long long ms = - (shortest->when_sec - now_sec)*1000 + - shortest->when_ms - now_ms; - - if (ms > 0) { - tvp->tv_sec = ms/1000; - tvp->tv_usec = (ms % 1000)*1000; - } else { - tvp->tv_sec = 0; - tvp->tv_usec = 0; - } - } else { - /* If we have to check for events but need to return - * ASAP because of AE_DONT_WAIT we need to set the timeout - * to zero */ - if (flags & AE_DONT_WAIT) { - tv.tv_sec = tv.tv_usec = 0; - tvp = &tv; - } else { - /* Otherwise we can block */ - tvp = NULL; /* wait forever */ - } - } - - numevents = aeApiPoll(eventLoop, tvp); - for (j = 0; j < numevents; j++) { - aeFileEvent *fe = &eventLoop->events[eventLoop->fired[j].fd]; - int mask = eventLoop->fired[j].mask; - int fd = eventLoop->fired[j].fd; - int rfired = 0; - - /* note the fe->mask & mask & ... code: maybe an already processed - * event removed an element that fired and we still didn't - * processed, so we check if the event is still valid. 
*/ - if (fe->mask & mask & AE_READABLE) { - rfired = 1; - fe->rfileProc(eventLoop,fd,fe->clientData,mask); - } - if (fe->mask & mask & AE_WRITABLE) { - if (!rfired || fe->wfileProc != fe->rfileProc) - fe->wfileProc(eventLoop,fd,fe->clientData,mask); - } - processed++; - } - } - /* Check time events */ - if (flags & AE_TIME_EVENTS) - processed += processTimeEvents(eventLoop); - - return processed; /* return the number of processed file/time events */ -} - -/* Wait for milliseconds until the given file descriptor becomes - * writable/readable/exception */ -int aeWait(int fd, int mask, long long milliseconds) { - struct pollfd pfd; - int retmask = 0, retval; - - memset(&pfd, 0, sizeof(pfd)); - pfd.fd = fd; - if (mask & AE_READABLE) pfd.events |= POLLIN; - if (mask & AE_WRITABLE) pfd.events |= POLLOUT; - - if ((retval = poll(&pfd, 1, milliseconds))== 1) { - if (pfd.revents & POLLIN) retmask |= AE_READABLE; - if (pfd.revents & POLLOUT) retmask |= AE_WRITABLE; - if (pfd.revents & POLLERR) retmask |= AE_WRITABLE; - if (pfd.revents & POLLHUP) retmask |= AE_WRITABLE; - return retmask; - } else { - return retval; - } -} - -void aeMain(aeEventLoop *eventLoop) { - eventLoop->stop = 0; - while (!eventLoop->stop) { - if (eventLoop->beforesleep != NULL) - eventLoop->beforesleep(eventLoop); - aeProcessEvents(eventLoop, AE_ALL_EVENTS); - } -} - -char *aeGetApiName(void) { - return aeApiName(); -} - -void aeSetBeforeSleepProc(aeEventLoop *eventLoop, aeBeforeSleepProc *beforesleep) { - eventLoop->beforesleep = beforesleep; -} diff --git a/cpp/src/plasma/thirdparty/ae/ae.h b/cpp/src/plasma/thirdparty/ae/ae.h deleted file mode 100644 index 827c4c9e4e5..00000000000 --- a/cpp/src/plasma/thirdparty/ae/ae.h +++ /dev/null @@ -1,123 +0,0 @@ -/* A simple event-driven programming library. Originally I wrote this code - * for the Jim's event-loop (Jim is a Tcl interpreter) but later translated - * it in form of a library for easy reuse. 
- * - * Copyright (c) 2006-2012, Salvatore Sanfilippo - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. 
- */ - -#ifndef __AE_H__ -#define __AE_H__ - -#include - -#define AE_OK 0 -#define AE_ERR -1 - -#define AE_NONE 0 -#define AE_READABLE 1 -#define AE_WRITABLE 2 - -#define AE_FILE_EVENTS 1 -#define AE_TIME_EVENTS 2 -#define AE_ALL_EVENTS (AE_FILE_EVENTS|AE_TIME_EVENTS) -#define AE_DONT_WAIT 4 - -#define AE_NOMORE -1 -#define AE_DELETED_EVENT_ID -1 - -/* Macros */ -#define AE_NOTUSED(V) ((void) V) - -struct aeEventLoop; - -/* Types and data structures */ -typedef void aeFileProc(struct aeEventLoop *eventLoop, int fd, void *clientData, int mask); -typedef int aeTimeProc(struct aeEventLoop *eventLoop, long long id, void *clientData); -typedef void aeEventFinalizerProc(struct aeEventLoop *eventLoop, void *clientData); -typedef void aeBeforeSleepProc(struct aeEventLoop *eventLoop); - -/* File event structure */ -typedef struct aeFileEvent { - int mask; /* one of AE_(READABLE|WRITABLE) */ - aeFileProc *rfileProc; - aeFileProc *wfileProc; - void *clientData; -} aeFileEvent; - -/* Time event structure */ -typedef struct aeTimeEvent { - long long id; /* time event identifier. 
*/ - long when_sec; /* seconds */ - long when_ms; /* milliseconds */ - aeTimeProc *timeProc; - aeEventFinalizerProc *finalizerProc; - void *clientData; - struct aeTimeEvent *next; -} aeTimeEvent; - -/* A fired event */ -typedef struct aeFiredEvent { - int fd; - int mask; -} aeFiredEvent; - -/* State of an event based program */ -typedef struct aeEventLoop { - int maxfd; /* highest file descriptor currently registered */ - int setsize; /* max number of file descriptors tracked */ - long long timeEventNextId; - time_t lastTime; /* Used to detect system clock skew */ - aeFileEvent *events; /* Registered events */ - aeFiredEvent *fired; /* Fired events */ - aeTimeEvent *timeEventHead; - int stop; - void *apidata; /* This is used for polling API specific data */ - aeBeforeSleepProc *beforesleep; -} aeEventLoop; - -/* Prototypes */ -aeEventLoop *aeCreateEventLoop(int setsize); -void aeDeleteEventLoop(aeEventLoop *eventLoop); -void aeStop(aeEventLoop *eventLoop); -int aeCreateFileEvent(aeEventLoop *eventLoop, int fd, int mask, - aeFileProc *proc, void *clientData); -void aeDeleteFileEvent(aeEventLoop *eventLoop, int fd, int mask); -int aeGetFileEvents(aeEventLoop *eventLoop, int fd); -long long aeCreateTimeEvent(aeEventLoop *eventLoop, long long milliseconds, - aeTimeProc *proc, void *clientData, - aeEventFinalizerProc *finalizerProc); -int aeDeleteTimeEvent(aeEventLoop *eventLoop, long long id); -int aeProcessEvents(aeEventLoop *eventLoop, int flags); -int aeWait(int fd, int mask, long long milliseconds); -void aeMain(aeEventLoop *eventLoop); -char *aeGetApiName(void); -void aeSetBeforeSleepProc(aeEventLoop *eventLoop, aeBeforeSleepProc *beforesleep); -int aeGetSetSize(aeEventLoop *eventLoop); -int aeResizeSetSize(aeEventLoop *eventLoop, int setsize); - -#endif diff --git a/cpp/src/plasma/thirdparty/ae/ae_epoll.c b/cpp/src/plasma/thirdparty/ae/ae_epoll.c deleted file mode 100644 index 2f70550a980..00000000000 --- a/cpp/src/plasma/thirdparty/ae/ae_epoll.c +++ /dev/null @@ 
-1,137 +0,0 @@ -/* Linux epoll(2) based ae.c module - * - * Copyright (c) 2009-2012, Salvatore Sanfilippo - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. 
- */ - - -#include - -typedef struct aeApiState { - int epfd; - struct epoll_event *events; -} aeApiState; - -static int aeApiCreate(aeEventLoop *eventLoop) { - aeApiState *state = zmalloc(sizeof(aeApiState)); - - if (!state) return -1; - state->events = zmalloc(sizeof(struct epoll_event)*eventLoop->setsize); - if (!state->events) { - zfree(state); - return -1; - } - state->epfd = epoll_create(1024); /* 1024 is just a hint for the kernel */ - if (state->epfd == -1) { - zfree(state->events); - zfree(state); - return -1; - } - eventLoop->apidata = state; - return 0; -} - -static int aeApiResize(aeEventLoop *eventLoop, int setsize) { - aeApiState *state = eventLoop->apidata; - - state->events = zrealloc(state->events, sizeof(struct epoll_event)*setsize); - return 0; -} - -static void aeApiFree(aeEventLoop *eventLoop) { - aeApiState *state = eventLoop->apidata; - - close(state->epfd); - zfree(state->events); - zfree(state); -} - -static int aeApiAddEvent(aeEventLoop *eventLoop, int fd, int mask) { - aeApiState *state = eventLoop->apidata; - struct epoll_event ee; - memset(&ee, 0, sizeof(struct epoll_event)); // avoid valgrind warning - /* If the fd was already monitored for some event, we need a MOD - * operation. Otherwise we need an ADD operation. */ - int op = eventLoop->events[fd].mask == AE_NONE ? 
- EPOLL_CTL_ADD : EPOLL_CTL_MOD; - - ee.events = 0; - mask |= eventLoop->events[fd].mask; /* Merge old events */ - if (mask & AE_READABLE) ee.events |= EPOLLIN; - if (mask & AE_WRITABLE) ee.events |= EPOLLOUT; - ee.data.fd = fd; - if (epoll_ctl(state->epfd,op,fd,&ee) == -1) return -1; - return 0; -} - -static void aeApiDelEvent(aeEventLoop *eventLoop, int fd, int delmask) { - aeApiState *state = eventLoop->apidata; - struct epoll_event ee; - memset(&ee, 0, sizeof(struct epoll_event)); // avoid valgrind warning - int mask = eventLoop->events[fd].mask & (~delmask); - - ee.events = 0; - if (mask & AE_READABLE) ee.events |= EPOLLIN; - if (mask & AE_WRITABLE) ee.events |= EPOLLOUT; - ee.data.fd = fd; - if (mask != AE_NONE) { - epoll_ctl(state->epfd,EPOLL_CTL_MOD,fd,&ee); - } else { - /* Note, Kernel < 2.6.9 requires a non null event pointer even for - * EPOLL_CTL_DEL. */ - epoll_ctl(state->epfd,EPOLL_CTL_DEL,fd,&ee); - } -} - -static int aeApiPoll(aeEventLoop *eventLoop, struct timeval *tvp) { - aeApiState *state = eventLoop->apidata; - int retval, numevents = 0; - - retval = epoll_wait(state->epfd,state->events,eventLoop->setsize, - tvp ? (tvp->tv_sec*1000 + tvp->tv_usec/1000) : -1); - if (retval > 0) { - int j; - - numevents = retval; - for (j = 0; j < numevents; j++) { - int mask = 0; - struct epoll_event *e = state->events+j; - - if (e->events & EPOLLIN) mask |= AE_READABLE; - if (e->events & EPOLLOUT) mask |= AE_WRITABLE; - if (e->events & EPOLLERR) mask |= AE_WRITABLE; - if (e->events & EPOLLHUP) mask |= AE_WRITABLE; - eventLoop->fired[j].fd = e->data.fd; - eventLoop->fired[j].mask = mask; - } - } - return numevents; -} - -static char *aeApiName(void) { - return "epoll"; -} diff --git a/cpp/src/plasma/thirdparty/ae/ae_evport.c b/cpp/src/plasma/thirdparty/ae/ae_evport.c deleted file mode 100644 index 5c317becb6f..00000000000 --- a/cpp/src/plasma/thirdparty/ae/ae_evport.c +++ /dev/null @@ -1,320 +0,0 @@ -/* ae.c module for illumos event ports. 
- * - * Copyright (c) 2012, Joyent, Inc. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - - -#include -#include -#include -#include - -#include -#include - -#include - -static int evport_debug = 0; - -/* - * This file implements the ae API using event ports, present on Solaris-based - * systems since Solaris 10. Using the event port interface, we associate file - * descriptors with the port. Each association also includes the set of poll(2) - * events that the consumer is interested in (e.g., POLLIN and POLLOUT). 
- * - * There's one tricky piece to this implementation: when we return events via - * aeApiPoll, the corresponding file descriptors become dissociated from the - * port. This is necessary because poll events are level-triggered, so if the - * fd didn't become dissociated, it would immediately fire another event since - * the underlying state hasn't changed yet. We must re-associate the file - * descriptor, but only after we know that our caller has actually read from it. - * The ae API does not tell us exactly when that happens, but we do know that - * it must happen by the time aeApiPoll is called again. Our solution is to - * keep track of the last fds returned by aeApiPoll and re-associate them next - * time aeApiPoll is invoked. - * - * To summarize, in this module, each fd association is EITHER (a) represented - * only via the in-kernel association OR (b) represented by pending_fds and - * pending_masks. (b) is only true for the last fds we returned from aeApiPoll, - * and only until we enter aeApiPoll again (at which point we restore the - * in-kernel association). - */ -#define MAX_EVENT_BATCHSZ 512 - -typedef struct aeApiState { - int portfd; /* event port */ - int npending; /* # of pending fds */ - int pending_fds[MAX_EVENT_BATCHSZ]; /* pending fds */ - int pending_masks[MAX_EVENT_BATCHSZ]; /* pending fds' masks */ -} aeApiState; - -static int aeApiCreate(aeEventLoop *eventLoop) { - int i; - aeApiState *state = zmalloc(sizeof(aeApiState)); - if (!state) return -1; - - state->portfd = port_create(); - if (state->portfd == -1) { - zfree(state); - return -1; - } - - state->npending = 0; - - for (i = 0; i < MAX_EVENT_BATCHSZ; i++) { - state->pending_fds[i] = -1; - state->pending_masks[i] = AE_NONE; - } - - eventLoop->apidata = state; - return 0; -} - -static int aeApiResize(aeEventLoop *eventLoop, int setsize) { - /* Nothing to resize here. 
*/ - return 0; -} - -static void aeApiFree(aeEventLoop *eventLoop) { - aeApiState *state = eventLoop->apidata; - - close(state->portfd); - zfree(state); -} - -static int aeApiLookupPending(aeApiState *state, int fd) { - int i; - - for (i = 0; i < state->npending; i++) { - if (state->pending_fds[i] == fd) - return (i); - } - - return (-1); -} - -/* - * Helper function to invoke port_associate for the given fd and mask. - */ -static int aeApiAssociate(const char *where, int portfd, int fd, int mask) { - int events = 0; - int rv, err; - - if (mask & AE_READABLE) - events |= POLLIN; - if (mask & AE_WRITABLE) - events |= POLLOUT; - - if (evport_debug) - fprintf(stderr, "%s: port_associate(%d, 0x%x) = ", where, fd, events); - - rv = port_associate(portfd, PORT_SOURCE_FD, fd, events, - (void *)(uintptr_t)mask); - err = errno; - - if (evport_debug) - fprintf(stderr, "%d (%s)\n", rv, rv == 0 ? "no error" : strerror(err)); - - if (rv == -1) { - fprintf(stderr, "%s: port_associate: %s\n", where, strerror(err)); - - if (err == EAGAIN) - fprintf(stderr, "aeApiAssociate: event port limit exceeded."); - } - - return rv; -} - -static int aeApiAddEvent(aeEventLoop *eventLoop, int fd, int mask) { - aeApiState *state = eventLoop->apidata; - int fullmask, pfd; - - if (evport_debug) - fprintf(stderr, "aeApiAddEvent: fd %d mask 0x%x\n", fd, mask); - - /* - * Since port_associate's "events" argument replaces any existing events, we - * must be sure to include whatever events are already associated when - * we call port_associate() again. - */ - fullmask = mask | eventLoop->events[fd].mask; - pfd = aeApiLookupPending(state, fd); - - if (pfd != -1) { - /* - * This fd was recently returned from aeApiPoll. It should be safe to - * assume that the consumer has processed that poll event, but we play - * it safer by simply updating pending_mask. The fd will be - * re-associated as usual when aeApiPoll is called again. 
- */ - if (evport_debug) - fprintf(stderr, "aeApiAddEvent: adding to pending fd %d\n", fd); - state->pending_masks[pfd] |= fullmask; - return 0; - } - - return (aeApiAssociate("aeApiAddEvent", state->portfd, fd, fullmask)); -} - -static void aeApiDelEvent(aeEventLoop *eventLoop, int fd, int mask) { - aeApiState *state = eventLoop->apidata; - int fullmask, pfd; - - if (evport_debug) - fprintf(stderr, "del fd %d mask 0x%x\n", fd, mask); - - pfd = aeApiLookupPending(state, fd); - - if (pfd != -1) { - if (evport_debug) - fprintf(stderr, "deleting event from pending fd %d\n", fd); - - /* - * This fd was just returned from aeApiPoll, so it's not currently - * associated with the port. All we need to do is update - * pending_mask appropriately. - */ - state->pending_masks[pfd] &= ~mask; - - if (state->pending_masks[pfd] == AE_NONE) - state->pending_fds[pfd] = -1; - - return; - } - - /* - * The fd is currently associated with the port. Like with the add case - * above, we must look at the full mask for the file descriptor before - * updating that association. We don't have a good way of knowing what the - * events are without looking into the eventLoop state directly. We rely on - * the fact that our caller has already updated the mask in the eventLoop. - */ - - fullmask = eventLoop->events[fd].mask; - if (fullmask == AE_NONE) { - /* - * We're removing *all* events, so use port_dissociate to remove the - * association completely. Failure here indicates a bug. - */ - if (evport_debug) - fprintf(stderr, "aeApiDelEvent: port_dissociate(%d)\n", fd); - - if (port_dissociate(state->portfd, PORT_SOURCE_FD, fd) != 0) { - perror("aeApiDelEvent: port_dissociate"); - abort(); /* will not return */ - } - } else if (aeApiAssociate("aeApiDelEvent", state->portfd, fd, - fullmask) != 0) { - /* - * ENOMEM is a potentially transient condition, but the kernel won't - * generally return it unless things are really bad. 
EAGAIN indicates - * we've reached an resource limit, for which it doesn't make sense to - * retry (counter-intuitively). All other errors indicate a bug. In any - * of these cases, the best we can do is to abort. - */ - abort(); /* will not return */ - } -} - -static int aeApiPoll(aeEventLoop *eventLoop, struct timeval *tvp) { - aeApiState *state = eventLoop->apidata; - struct timespec timeout, *tsp; - int mask, i; - uint_t nevents; - port_event_t event[MAX_EVENT_BATCHSZ]; - - /* - * If we've returned fd events before, we must re-associate them with the - * port now, before calling port_get(). See the block comment at the top of - * this file for an explanation of why. - */ - for (i = 0; i < state->npending; i++) { - if (state->pending_fds[i] == -1) - /* This fd has since been deleted. */ - continue; - - if (aeApiAssociate("aeApiPoll", state->portfd, - state->pending_fds[i], state->pending_masks[i]) != 0) { - /* See aeApiDelEvent for why this case is fatal. */ - abort(); - } - - state->pending_masks[i] = AE_NONE; - state->pending_fds[i] = -1; - } - - state->npending = 0; - - if (tvp != NULL) { - timeout.tv_sec = tvp->tv_sec; - timeout.tv_nsec = tvp->tv_usec * 1000; - tsp = &timeout; - } else { - tsp = NULL; - } - - /* - * port_getn can return with errno == ETIME having returned some events (!). - * So if we get ETIME, we check nevents, too. - */ - nevents = 1; - if (port_getn(state->portfd, event, MAX_EVENT_BATCHSZ, &nevents, - tsp) == -1 && (errno != ETIME || nevents == 0)) { - if (errno == ETIME || errno == EINTR) - return 0; - - /* Any other error indicates a bug. 
*/ - perror("aeApiPoll: port_get"); - abort(); - } - - state->npending = nevents; - - for (i = 0; i < nevents; i++) { - mask = 0; - if (event[i].portev_events & POLLIN) - mask |= AE_READABLE; - if (event[i].portev_events & POLLOUT) - mask |= AE_WRITABLE; - - eventLoop->fired[i].fd = event[i].portev_object; - eventLoop->fired[i].mask = mask; - - if (evport_debug) - fprintf(stderr, "aeApiPoll: fd %d mask 0x%x\n", - (int)event[i].portev_object, mask); - - state->pending_fds[i] = event[i].portev_object; - state->pending_masks[i] = (uintptr_t)event[i].portev_user; - } - - return nevents; -} - -static char *aeApiName(void) { - return "evport"; -} diff --git a/cpp/src/plasma/thirdparty/ae/ae_kqueue.c b/cpp/src/plasma/thirdparty/ae/ae_kqueue.c deleted file mode 100644 index 6796f4ceb59..00000000000 --- a/cpp/src/plasma/thirdparty/ae/ae_kqueue.c +++ /dev/null @@ -1,138 +0,0 @@ -/* Kqueue(2)-based ae.c module - * - * Copyright (C) 2009 Harish Mallipeddi - harish.mallipeddi@gmail.com - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - - -#include -#include -#include - -typedef struct aeApiState { - int kqfd; - struct kevent *events; -} aeApiState; - -static int aeApiCreate(aeEventLoop *eventLoop) { - aeApiState *state = zmalloc(sizeof(aeApiState)); - - if (!state) return -1; - state->events = zmalloc(sizeof(struct kevent)*eventLoop->setsize); - if (!state->events) { - zfree(state); - return -1; - } - state->kqfd = kqueue(); - if (state->kqfd == -1) { - zfree(state->events); - zfree(state); - return -1; - } - eventLoop->apidata = state; - return 0; -} - -static int aeApiResize(aeEventLoop *eventLoop, int setsize) { - aeApiState *state = eventLoop->apidata; - - state->events = zrealloc(state->events, sizeof(struct kevent)*setsize); - return 0; -} - -static void aeApiFree(aeEventLoop *eventLoop) { - aeApiState *state = eventLoop->apidata; - - close(state->kqfd); - zfree(state->events); - zfree(state); -} - -static int aeApiAddEvent(aeEventLoop *eventLoop, int fd, int mask) { - aeApiState *state = eventLoop->apidata; - struct kevent ke; - - if (mask & AE_READABLE) { - EV_SET(&ke, fd, EVFILT_READ, EV_ADD, 0, 0, NULL); - if (kevent(state->kqfd, &ke, 1, NULL, 0, NULL) == -1) return -1; - } - if (mask & AE_WRITABLE) { - EV_SET(&ke, fd, EVFILT_WRITE, EV_ADD, 0, 0, NULL); - if (kevent(state->kqfd, &ke, 1, NULL, 0, NULL) == -1) return -1; - } - return 0; -} - -static void aeApiDelEvent(aeEventLoop *eventLoop, int fd, int mask) { - aeApiState *state = 
eventLoop->apidata; - struct kevent ke; - - if (mask & AE_READABLE) { - EV_SET(&ke, fd, EVFILT_READ, EV_DELETE, 0, 0, NULL); - kevent(state->kqfd, &ke, 1, NULL, 0, NULL); - } - if (mask & AE_WRITABLE) { - EV_SET(&ke, fd, EVFILT_WRITE, EV_DELETE, 0, 0, NULL); - kevent(state->kqfd, &ke, 1, NULL, 0, NULL); - } -} - -static int aeApiPoll(aeEventLoop *eventLoop, struct timeval *tvp) { - aeApiState *state = eventLoop->apidata; - int retval, numevents = 0; - - if (tvp != NULL) { - struct timespec timeout; - timeout.tv_sec = tvp->tv_sec; - timeout.tv_nsec = tvp->tv_usec * 1000; - retval = kevent(state->kqfd, NULL, 0, state->events, eventLoop->setsize, - &timeout); - } else { - retval = kevent(state->kqfd, NULL, 0, state->events, eventLoop->setsize, - NULL); - } - - if (retval > 0) { - int j; - - numevents = retval; - for(j = 0; j < numevents; j++) { - int mask = 0; - struct kevent *e = state->events+j; - - if (e->filter == EVFILT_READ) mask |= AE_READABLE; - if (e->filter == EVFILT_WRITE) mask |= AE_WRITABLE; - eventLoop->fired[j].fd = e->ident; - eventLoop->fired[j].mask = mask; - } - } - return numevents; -} - -static char *aeApiName(void) { - return "kqueue"; -} diff --git a/cpp/src/plasma/thirdparty/ae/ae_select.c b/cpp/src/plasma/thirdparty/ae/ae_select.c deleted file mode 100644 index c039a8ea312..00000000000 --- a/cpp/src/plasma/thirdparty/ae/ae_select.c +++ /dev/null @@ -1,106 +0,0 @@ -/* Select()-based ae.c module. - * - * Copyright (c) 2009-2012, Salvatore Sanfilippo - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - - -#include -#include - -typedef struct aeApiState { - fd_set rfds, wfds; - /* We need to have a copy of the fd sets as it's not safe to reuse - * FD sets after select(). */ - fd_set _rfds, _wfds; -} aeApiState; - -static int aeApiCreate(aeEventLoop *eventLoop) { - aeApiState *state = zmalloc(sizeof(aeApiState)); - - if (!state) return -1; - FD_ZERO(&state->rfds); - FD_ZERO(&state->wfds); - eventLoop->apidata = state; - return 0; -} - -static int aeApiResize(aeEventLoop *eventLoop, int setsize) { - /* Just ensure we have enough room in the fd_set type. 
*/ - if (setsize >= FD_SETSIZE) return -1; - return 0; -} - -static void aeApiFree(aeEventLoop *eventLoop) { - zfree(eventLoop->apidata); -} - -static int aeApiAddEvent(aeEventLoop *eventLoop, int fd, int mask) { - aeApiState *state = eventLoop->apidata; - - if (mask & AE_READABLE) FD_SET(fd,&state->rfds); - if (mask & AE_WRITABLE) FD_SET(fd,&state->wfds); - return 0; -} - -static void aeApiDelEvent(aeEventLoop *eventLoop, int fd, int mask) { - aeApiState *state = eventLoop->apidata; - - if (mask & AE_READABLE) FD_CLR(fd,&state->rfds); - if (mask & AE_WRITABLE) FD_CLR(fd,&state->wfds); -} - -static int aeApiPoll(aeEventLoop *eventLoop, struct timeval *tvp) { - aeApiState *state = eventLoop->apidata; - int retval, j, numevents = 0; - - memcpy(&state->_rfds,&state->rfds,sizeof(fd_set)); - memcpy(&state->_wfds,&state->wfds,sizeof(fd_set)); - - retval = select(eventLoop->maxfd+1, - &state->_rfds,&state->_wfds,NULL,tvp); - if (retval > 0) { - for (j = 0; j <= eventLoop->maxfd; j++) { - int mask = 0; - aeFileEvent *fe = &eventLoop->events[j]; - - if (fe->mask == AE_NONE) continue; - if (fe->mask & AE_READABLE && FD_ISSET(j,&state->_rfds)) - mask |= AE_READABLE; - if (fe->mask & AE_WRITABLE && FD_ISSET(j,&state->_wfds)) - mask |= AE_WRITABLE; - eventLoop->fired[numevents].fd = j; - eventLoop->fired[numevents].mask = mask; - numevents++; - } - } - return numevents; -} - -static char *aeApiName(void) { - return "select"; -} diff --git a/cpp/src/plasma/thirdparty/ae/config.h b/cpp/src/plasma/thirdparty/ae/config.h deleted file mode 100644 index 4f8e1ea1bc3..00000000000 --- a/cpp/src/plasma/thirdparty/ae/config.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2009-2012, Salvatore Sanfilippo - * All rights reserved. 
- * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. 
- */ - -#ifndef __CONFIG_H -#define __CONFIG_H - -#ifdef __APPLE__ -#include -#endif - -/* Test for polling API */ -#ifdef __linux__ -#define HAVE_EPOLL 1 -#endif - -#if (defined(__APPLE__) && defined(MAC_OS_X_VERSION_10_6)) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined (__NetBSD__) -#define HAVE_KQUEUE 1 -#endif - -#ifdef __sun -#include -#ifdef _DTRACE_VERSION -#define HAVE_EVPORT 1 -#endif -#endif - - -#endif diff --git a/cpp/src/plasma/thirdparty/ae/zmalloc.h b/cpp/src/plasma/thirdparty/ae/zmalloc.h deleted file mode 100644 index 6c27dd4e5c3..00000000000 --- a/cpp/src/plasma/thirdparty/ae/zmalloc.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2009-2012, Salvatore Sanfilippo - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef _ZMALLOC_H -#define _ZMALLOC_H - -#ifndef zmalloc -#define zmalloc malloc -#endif - -#ifndef zfree -#define zfree free -#endif - -#ifndef zrealloc -#define zrealloc realloc -#endif - -#endif /* _ZMALLOC_H */ diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 80e87415de0..f59beebe6e0 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -22,14 +22,6 @@ cpp/cmake_modules/FindPythonLibsNew.cmake cpp/cmake_modules/SnappyCMakeLists.txt cpp/cmake_modules/SnappyConfig.h cpp/src/parquet/.parquetcppversion -cpp/src/plasma/thirdparty/ae/ae.c -cpp/src/plasma/thirdparty/ae/ae.h -cpp/src/plasma/thirdparty/ae/ae_epoll.c -cpp/src/plasma/thirdparty/ae/ae_evport.c -cpp/src/plasma/thirdparty/ae/ae_kqueue.c -cpp/src/plasma/thirdparty/ae/ae_select.c -cpp/src/plasma/thirdparty/ae/config.h -cpp/src/plasma/thirdparty/ae/zmalloc.h cpp/src/plasma/thirdparty/dlmalloc.c dev/release/rat_exclude_files.txt dev/tasks/linux-packages/debian.ubuntu-trusty/compat diff --git a/python/pyarrow/_plasma.pyx b/python/pyarrow/_plasma.pyx index 04963b63165..df96e745fa7 100644 --- a/python/pyarrow/_plasma.pyx +++ b/python/pyarrow/_plasma.pyx @@ -122,15 +122,17 @@ cdef extern from "plasma/client.h" nogil: CStatus List(CObjectTable* objects) - CStatus Subscribe(int* fd) + CStatus Subscribe() CStatus DecodeNotification(const uint8_t* buffer, CUniqueID* object_id, int64_t* data_size, int64_t* 
metadata_size) - CStatus GetNotification(int fd, CUniqueID* object_id, + CStatus GetNotification(CUniqueID* object_id, int64_t* data_size, int64_t* metadata_size) + int GetNativeNotificationHandle() + CStatus Disconnect() CStatus Delete(const c_vector[CUniqueID] object_ids) @@ -265,12 +267,10 @@ cdef class PlasmaClient: cdef: shared_ptr[CPlasmaClient] client - int notification_fd c_string store_socket_name def __cinit__(self): self.client.reset(new CPlasmaClient()) - self.notification_fd = -1 self.store_socket_name = b"" cdef _get_object_buffers(self, object_ids, int64_t timeout_ms, @@ -622,13 +622,14 @@ cdef class PlasmaClient: def subscribe(self): """Subscribe to notifications about sealed objects.""" with nogil: - check_status(self.client.get().Subscribe(&self.notification_fd)) + check_status(self.client.get().Subscribe()) def get_notification_socket(self): """ Get the notification socket. """ - return compat.get_socket_from_fd(self.notification_fd, + cdef int fd = self.client.get().GetNativeNotificationHandle() + return compat.get_socket_from_fd(fd, family=socket.AF_UNIX, type=socket.SOCK_STREAM) @@ -674,8 +675,7 @@ cdef class PlasmaClient: cdef int64_t metadata_size with nogil: check_status(self.client.get() - .GetNotification(self.notification_fd, - &object_id.data, + .GetNotification(&object_id.data, &data_size, &metadata_size)) return object_id, data_size, metadata_size From 040e40c04cb9b07bd557858363c653428dcb998c Mon Sep 17 00:00:00 2001 From: suquark Date: Fri, 22 Mar 2019 15:41:46 +0800 Subject: [PATCH 02/52] fix according to github comments --- cpp/src/plasma/io/basic_connection.cc | 12 ++++- cpp/src/plasma/io/basic_connection.h | 19 +++++-- cpp/src/plasma/io/connection.cc | 63 +++++++++++----------- cpp/src/plasma/io/connection.h | 14 ++--- cpp/src/plasma/store.cc | 3 +- cpp/src/plasma/test/serialization_tests.cc | 3 +- 6 files changed, 61 insertions(+), 53 deletions(-) diff --git a/cpp/src/plasma/io/basic_connection.cc 
b/cpp/src/plasma/io/basic_connection.cc index ce5ea607910..1ba9571733b 100644 --- a/cpp/src/plasma/io/basic_connection.cc +++ b/cpp/src/plasma/io/basic_connection.cc @@ -91,7 +91,7 @@ template Connection::~Connection() { // If there are any pending messages, invoke their callbacks with an IOError status. for (const auto& write_buffer : async_write_queue_) { - write_buffer->handler( + write_buffer->Handle( std::error_code(static_cast(std::errc::io_error), std::system_category())); } } @@ -224,15 +224,23 @@ void Connection::DoAsyncWrites() { [this, this_ptr, num_messages](const std::error_code& ec, size_t bytes_transferred) { bytes_written_ += bytes_transferred; + bool close_connection = false; // Call the handlers for the written messages. for (int i = 0; i < num_messages; i++) { auto write_buffer = std::move(async_write_queue_.front()); - write_buffer->handler(ec); + auto return_code = write_buffer->Handle(ec); + if (return_code != AsyncWriteCallbackCode::OK) { + close_connection = true; + } async_write_queue_.pop_front(); // release object } // We finished writing, so mark that we're no longer doing an // async write. async_write_in_flight_ = false; + if (close_connection) { + Close(); + return; + } // If there is more to write, try to write the rest. if (!async_write_queue_.empty()) { DoAsyncWrites(); diff --git a/cpp/src/plasma/io/basic_connection.h b/cpp/src/plasma/io/basic_connection.h index 32f8445e867..bb878c3e733 100644 --- a/cpp/src/plasma/io/basic_connection.h +++ b/cpp/src/plasma/io/basic_connection.h @@ -34,7 +34,13 @@ namespace plasma { namespace io { -using AsyncWriteCallback = std::function; +enum class AsyncWriteCallbackCode { + OK, + DISCONNECT, + UNKNOWN_ERROR, +}; + +using AsyncWriteCallback = std::function; // TODO(suquark): Change it according to the platform. 
using PlasmaStream = asio::basic_stream_socket; using PlasmaAcceptor = asio::local::stream_protocol::acceptor; @@ -47,9 +53,14 @@ PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& /// A message that is queued for writing asynchronously. struct AsyncWriteBuffer { - virtual void ToBuffers(std::vector& message_buffers) {} - AsyncWriteCallback handler; - virtual ~AsyncWriteBuffer() {} + virtual void ToBuffers(std::vector& message_buffers) = 0; + virtual ~AsyncWriteBuffer(){}; + inline AsyncWriteCallbackCode Handle(const std::error_code& ec) { + return handler_(ec); + }; + + protected: + AsyncWriteCallback handler_; }; template diff --git a/cpp/src/plasma/io/connection.cc b/cpp/src/plasma/io/connection.cc index 64a42b376bc..9e6ea91991a 100644 --- a/cpp/src/plasma/io/connection.cc +++ b/cpp/src/plasma/io/connection.cc @@ -17,7 +17,6 @@ #include "plasma/io/connection.h" -#include #include #include #include @@ -30,7 +29,7 @@ // TODO(pcm): Replace our own custom message header (message type, // message length, plasma protocol verion) with one that is serialized // using flatbuffers. 
-constexpr int64_t kPlasmaProtocolVersion = 0x0000000000000000; +constexpr int64_t kPlasmaProtocolVersion = 0x504C41534D410000; // PLASMA\0\0 namespace plasma { namespace io { @@ -55,7 +54,7 @@ struct AsyncMessageWriteBuffer : public AsyncWriteBuffer { : write_version(version), write_type(type), write_length(length) { write_message.resize(length); write_message.assign(message, message + length); - handler = callback; + AsyncWriteBuffer::handler_ = callback; } void ToBuffers(std::vector& message_buffers) override { @@ -132,7 +131,7 @@ Status ServerConnection::WriteMessage(int64_t type, int64_t length, void ServerConnection::WriteMessageAsync(int64_t type, int64_t length, const uint8_t* message, const AsyncWriteCallback& handler) { - auto write_buffer = std::unique_ptr(new AsyncMessageWriteBuffer( + auto write_buffer = std::unique_ptr(new AsyncMessageWriteBuffer( kPlasmaProtocolVersion, type, length, message, handler)); PlasmaConnection::WriteBufferAsync(std::move(write_buffer)); } @@ -177,17 +176,13 @@ ServerConnection::ServerConnection(PlasmaStream&& stream) : PlasmaConnection(std::move(stream)) {} std::shared_ptr ClientConnection::Create( - PlasmaStream&& stream, MessageHandler& message_handler, - const std::string& debug_label) { + PlasmaStream&& stream, MessageHandler& message_handler) { return std::shared_ptr( - new ClientConnection(std::move(stream), message_handler, debug_label)); + new ClientConnection(std::move(stream), message_handler)); } -ClientConnection::ClientConnection(PlasmaStream&& stream, MessageHandler& message_handler, - const std::string& debug_label) - : ServerConnection(std::move(stream)), - debug_label_(debug_label), - message_handler_(message_handler) {} +ClientConnection::ClientConnection(PlasmaStream&& stream, MessageHandler& message_handler) + : ServerConnection(std::move(stream)), message_handler_(message_handler) {} std::shared_ptr ClientConnection::shared_from_this() { return 
std::static_pointer_cast(ServerConnection::shared_from_this()); @@ -206,10 +201,11 @@ void ClientConnection::ProcessMessages() { std::placeholders::_1)); // Ignore byte_transferred } -void ClientConnection::ProcessMessageHeader(const std::error_code& error) { - if (error) { +void ClientConnection::ProcessMessageHeader(const std::error_code& ec) { + auto status = asio_to_arrow_status(ec); + if (!status.ok()) { // If there was an error, disconnect the client. - ProcessError(error); + ProcessError(status); return; } @@ -225,24 +221,20 @@ void ClientConnection::ProcessMessageHeader(const std::error_code& error) { std::placeholders::_1)); } -void ClientConnection::ProcessMessageBody(const std::error_code& error) { - if (error) { - ProcessError(error); +void ClientConnection::ProcessMessageBody(const std::error_code& ec) { + auto status = asio_to_arrow_status(ec); + if (!status.ok()) { + // If there was an error, disconnect the client. + ProcessError(status); return; } - auto start = std::chrono::system_clock::now(); + ProcessMessage(read_type_, read_length_, read_message_.data()); - auto end = std::chrono::system_clock::now(); - auto interval = std::chrono::duration(end - start); - if (interval.count() > 100.0) { - ARROW_LOG(WARNING) << "[" << debug_label_ << "]ProcessMessage with type " - << read_type_ << " took " << interval.count() << " ms."; - } } -void ClientConnection::ProcessError(const std::error_code& ec) { - ARROW_LOG(ERROR) - << "Failed when processing message. Disconnecting the client. Error code = " << ec; +void ClientConnection::ProcessError(const Status& status) { + ARROW_LOG(ERROR) << "Failed when processing message. Disconnecting the client. (" + << status << ")"; // If there was an error, disconnect the client. 
PlasmaConnection::Close(); } @@ -252,6 +244,8 @@ void ClientConnection::ProcessMessage(int64_t type, int64_t length, const uint8_ } struct AsyncObjectNotificationWriteBuffer : public AsyncWriteBuffer { + ~AsyncObjectNotificationWriteBuffer() override {} + static std::unique_ptr MakeDeletion( const ObjectID& object_id) { auto message = new std::vector(); @@ -281,20 +275,23 @@ struct AsyncObjectNotificationWriteBuffer : public AsyncWriteBuffer { // Serialize the object. notification_msg.reset(message); size = message->size(); - handler = [](const asio::error_code& status) { + AsyncWriteBuffer::handler_ = + [](const asio::error_code& status) -> AsyncWriteCallbackCode { auto errno_ = status.value(); if (!errno_) { - return; + return AsyncWriteCallbackCode::OK; } if (errno_ == EAGAIN || errno_ == EWOULDBLOCK || errno_ == EINTR) { ARROW_LOG(DEBUG) << "The socket's send buffer is full, so we are caching this " "notification and will send it later."; ARROW_LOG(WARNING) << "Blocked unexpectly when sending message async."; + return AsyncWriteCallbackCode::OK; } else { ARROW_LOG(WARNING) << "Failed to send notification to client."; if (errno_ == EPIPE) { - // TODO(suquark): We could probably close the socket here. + return AsyncWriteCallbackCode::DISCONNECT; } + return AsyncWriteCallbackCode::UNKNOWN_ERROR; } }; } @@ -308,7 +305,7 @@ Status ClientConnection::SendFd(int fd) { void ClientConnection::SendObjectDeletionAsync(const ObjectID& object_id) { auto raw_ptr = AsyncObjectNotificationWriteBuffer::MakeDeletion(object_id).release(); auto write_buffer = - std::unique_ptr(static_cast(raw_ptr)); + std::unique_ptr(static_cast(raw_ptr)); // Attempt to send a notification about this object ID. 
WriteBufferAsync(std::move(write_buffer)); } @@ -318,7 +315,7 @@ void ClientConnection::SendObjectReadyAsync(const ObjectID& object_id, auto raw_ptr = AsyncObjectNotificationWriteBuffer::MakeReady(object_id, entry).release(); auto write_buffer = - std::unique_ptr(static_cast(raw_ptr)); + std::unique_ptr(static_cast(raw_ptr)); // Attempt to send a notification about this object ID. WriteBufferAsync(std::move(write_buffer)); } diff --git a/cpp/src/plasma/io/connection.h b/cpp/src/plasma/io/connection.h index a87a124211a..716bc6611e5 100644 --- a/cpp/src/plasma/io/connection.h +++ b/cpp/src/plasma/io/connection.h @@ -99,11 +99,9 @@ class ClientConnection : public ServerConnection { /// /// \param stream The client stream. /// \param message_handler A reference to the message handler. - /// \param debug_label The debug label. /// \return std::shared_ptr. static std::shared_ptr Create(PlasmaStream&& stream, - MessageHandler& message_handler, - const std::string& debug_label); + MessageHandler& message_handler); std::shared_ptr shared_from_this(); @@ -160,15 +158,11 @@ class ClientConnection : public ServerConnection { void ProcessMessage(int64_t type, int64_t length, const uint8_t* data); /// Process an error from reading the message from the client. - /// \param ec The returned error code. - void ProcessError(const std::error_code& ec); + /// \param status The status code. + void ProcessError(const Status& status); /// A private constructor for a node client connection. - ClientConnection(PlasmaStream&& stream, MessageHandler& message_handler, - const std::string& debug_label); - - /// A label used for debug messages. - const std::string debug_label_; + ClientConnection(PlasmaStream&& stream, MessageHandler& message_handler); /// The handler for a message from the client. 
MessageHandler message_handler_; diff --git a/cpp/src/plasma/store.cc b/cpp/src/plasma/store.cc index b83697c5ff5..49aeb00ee1f 100644 --- a/cpp/src/plasma/store.cc +++ b/cpp/src/plasma/store.cc @@ -722,8 +722,7 @@ void PlasmaStore::HandleAccept(const asio::error_code& error) { } }; // Accept a new local client and dispatch it to the store. - auto new_connection = ClientConnection::Create(std::move(stream_), message_handler, - "plasma_store_client"); + auto new_connection = ClientConnection::Create(std::move(stream_), message_handler); // Insert the client before processing messages. connected_clients_.insert(new_connection); // Process our new connection. diff --git a/cpp/src/plasma/test/serialization_tests.cc b/cpp/src/plasma/test/serialization_tests.cc index 3cfd4c5c4c8..1ad8fb4393a 100644 --- a/cpp/src/plasma/test/serialization_tests.cc +++ b/cpp/src/plasma/test/serialization_tests.cc @@ -42,8 +42,7 @@ class TestPlasmaSerialization : public ::testing::Test { io::MessageHandler monk_handler = [](std::shared_ptr client, int64_t type, int64_t length, const uint8_t* msg) {}; - server_ = - ClientConnection::Create(std::move(parentSocket), monk_handler, "PlasmaClient"); + server_ = ClientConnection::Create(std::move(parentSocket), monk_handler); } void TearDown() override { From d6ea8666d0be24cf16ec3ef64101a38714f903b5 Mon Sep 17 00:00:00 2001 From: suquark Date: Fri, 22 Mar 2019 17:44:13 +0800 Subject: [PATCH 03/52] lint --- cpp/src/plasma/io/basic_connection.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/src/plasma/io/basic_connection.h b/cpp/src/plasma/io/basic_connection.h index bb878c3e733..c121dedefcf 100644 --- a/cpp/src/plasma/io/basic_connection.h +++ b/cpp/src/plasma/io/basic_connection.h @@ -54,10 +54,8 @@ PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& /// A message that is queued for writing asynchronously. 
struct AsyncWriteBuffer { virtual void ToBuffers(std::vector& message_buffers) = 0; - virtual ~AsyncWriteBuffer(){}; - inline AsyncWriteCallbackCode Handle(const std::error_code& ec) { - return handler_(ec); - }; + virtual ~AsyncWriteBuffer() {} + inline AsyncWriteCallbackCode Handle(const std::error_code& ec) { return handler_(ec); } protected: AsyncWriteCallback handler_; From 168dea4857b8091e91d005b5fe2a1de56a5da1e4 Mon Sep 17 00:00:00 2001 From: suquark Date: Fri, 22 Mar 2019 19:14:40 +0800 Subject: [PATCH 04/52] prevent the store from dying --- cpp/src/plasma/io/connection.cc | 23 ++++++++++++++++++++--- cpp/src/plasma/store.cc | 19 +++++++------------ 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/cpp/src/plasma/io/connection.cc b/cpp/src/plasma/io/connection.cc index 9e6ea91991a..4cd4aa4b401 100644 --- a/cpp/src/plasma/io/connection.cc +++ b/cpp/src/plasma/io/connection.cc @@ -210,8 +210,13 @@ void ClientConnection::ProcessMessageHeader(const std::error_code& ec) { } // If there was no error, make sure the protocol version matches. - // TODO(suquark): Don't let server die here. - ARROW_CHECK(read_version_ == kPlasmaProtocolVersion); + if (read_version_ != kPlasmaProtocolVersion) { + status = Status::ProtocolError( + "Expected Plasma message protocol version: ", kPlasmaProtocolVersion, + ", got protocol version: ", read_version_); + ProcessError(status); + return; + } // Resize the message buffer to match the received length. read_message_.resize(read_length_); ServerConnection::bytes_read_ += read_length_; @@ -298,7 +303,19 @@ struct AsyncObjectNotificationWriteBuffer : public AsyncWriteBuffer { }; Status ClientConnection::SendFd(int fd) { - ARROW_CHECK(send_fd(GetNativeHandle(), fd)); + // Only send the file descriptor if it hasn't been sent (see analogous + // logic in GetStoreFd in client.cc). 
+ if (used_fds.find(fd) == used_fds.end()) { + auto ec = send_fd(GetNativeHandle(), fd); + if (ec <= 0) { + if (ec == 0) { + return Status::IOError("Encountered unexpected EOF"); + } else { + return Status::IOError("Unknown I/O Error"); + } + } + used_fds.insert(fd); // Succeed, record the fd. + } return Status::OK(); } diff --git a/cpp/src/plasma/store.cc b/cpp/src/plasma/store.cc index 49aeb00ee1f..44fba5413f7 100644 --- a/cpp/src/plasma/store.cc +++ b/cpp/src/plasma/store.cc @@ -106,14 +106,10 @@ struct GetRequest { if (s.ok()) { // Send all of the file descriptors for the present objects. for (int store_fd : store_fds) { - // Only send the file descriptor if it hasn't been sent (see analogous - // logic in GetStoreFd in client.cc). - if (client->used_fds.find(store_fd) == client->used_fds.end()) { - auto status = client->SendFd(store_fd); - if (!status.ok()) { - ARROW_LOG(ERROR) << "Failed to send a mmap fd to client"; - } - client->used_fds.insert(store_fd); + auto status = client->SendFd(store_fd); + if (!status.ok()) { + // TODO(suquark): Should we close the client here? + ARROW_LOG(ERROR) << "Failed to send a mmap fd to client"; } } } @@ -797,13 +793,12 @@ Status PlasmaStore::ProcessClientMessage(const std::shared_ptr RETURN_NOT_OK(SendCreateReply(client, object_id, &object, error_code, mmap_size)); // Only send the file descriptor if it hasn't been sent (see analogous // logic in GetStoreFd in client.cc). Similar in ReturnFromGet. - if (error_code == PlasmaError::OK && device_num == 0 && - client->used_fds.find(object.store_fd) == client->used_fds.end()) { + if (error_code == PlasmaError::OK && device_num == 0) { auto status = client->SendFd(object.store_fd); if (!status.ok()) { - ARROW_LOG(ERROR) << "Failed to send a mmap fd to the client."; + // TODO(suquark): Should we close the client here? 
+ ARROW_LOG(ERROR) << "Failed to send a mmap fd to client"; } - client->used_fds.insert(object.store_fd); } } break; case MessageType::PlasmaCreateAndSealRequest: { From 323f6339e093fb543f9fdf380f3d923503b8037a Mon Sep 17 00:00:00 2001 From: suquark Date: Wed, 27 Mar 2019 15:55:02 +0800 Subject: [PATCH 05/52] fix ARROW_CHECK --- cpp/src/plasma/client.cc | 63 +++++++++++++++++++++------------ cpp/src/plasma/io/connection.cc | 4 ++- cpp/src/plasma/store.cc | 53 ++++++++++++++++++--------- cpp/src/plasma/store.h | 5 +++ 4 files changed, 85 insertions(+), 40 deletions(-) diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index f4e713afb9d..58aaee960b6 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -248,8 +248,9 @@ class PlasmaClient::Impl : public std::enable_shared_from_thisRecvFd(&fd); - ARROW_CHECK(status.ok() && fd >= 0) << "recv not successful"; - return fd; + RETURN_NOT_OK(store_conn_->RecvFd(fd)); } else { - return entry->second->fd(); + *fd = entry->second->fd(); } + return Status::OK(); } void PlasmaClient::Impl::IncrementObjectCount(const ObjectID& object_id, @@ -398,7 +397,8 @@ Status PlasmaClient::Impl::Create(const ObjectID& object_id, int64_t data_size, // If the CreateReply included an error, then the store will not send a file // descriptor. if (device_num == 0) { - int fd = GetStoreFd(store_fd); + int fd; + RETURN_NOT_OK(GetStoreFd(store_fd, &fd)); ARROW_CHECK(object.data_size == data_size); ARROW_CHECK(object.metadata_size == metadata_size); // The metadata should come right after the data. @@ -484,8 +484,11 @@ Status PlasmaClient::Impl::GetBuffers( // This client created the object but hasn't sealed it. If we call Get // with no timeout, we will deadlock, because this client won't be able to // call Seal. 
- ARROW_CHECK(timeout_ms != -1) - << "Plasma client called get on an unsealed object that it created"; + if (timeout_ms != -1) { + return Status::Invalid( + "Plasma client called get on an" + " unsealed object that it created"); + } ARROW_LOG(WARNING) << "Attempting to get an object that this client created but hasn't sealed."; all_present = false; @@ -536,7 +539,8 @@ Status PlasmaClient::Impl::GetBuffers( // in the subsequent loop based on just the store file descriptor and without // having to know the relevant file descriptor received from recv_fd. for (size_t i = 0; i < store_fds.size(); i++) { - int fd = GetStoreFd(store_fds[i]); + int fd; + RETURN_NOT_OK(GetStoreFd(store_fds[i], &fd)); LookupOrMmap(fd, store_fds[i], mmap_sizes[i]); } @@ -631,11 +635,14 @@ Status PlasmaClient::Impl::Release(const ObjectID& object_id) { return Status::OK(); } auto object_entry = objects_in_use_.find(object_id); - ARROW_CHECK(object_entry != objects_in_use_.end()); - object_entry->second->count -= 1; - ARROW_CHECK(object_entry->second->count >= 0); + if (object_entry == objects_in_use_.end()) { + return Status::Invalid("Trying to release a non-existing object."); + } + auto& entry = *object_entry->second; + entry.count -= 1; + ARROW_CHECK(entry.count >= 0) << "Got negative ref count."; // Check if the client is no longer using this object. - if (object_entry->second->count == 0) { + if (entry.count == 0) { // Tell the store that the client no longer needs the object. 
RETURN_NOT_OK(MarkObjectUnused(object_id)); RETURN_NOT_OK(SendReleaseRequest(store_conn_, object_id)); @@ -773,22 +780,29 @@ Status PlasmaClient::Impl::Seal(const ObjectID& object_id) { Status PlasmaClient::Impl::Abort(const ObjectID& object_id) { auto object_entry = objects_in_use_.find(object_id); - ARROW_CHECK(object_entry != objects_in_use_.end()) - << "Plasma client called abort on an object without a reference to it"; - ARROW_CHECK(!object_entry->second->is_sealed) - << "Plasma client called abort on a sealed object"; + if (object_entry == objects_in_use_.end()) { + return Status::Invalid( + "Plasma client called abort on " + "an object without a reference to it"); + } + + auto& entry = *object_entry->second; + + if (entry.is_sealed) { + return Status::Invalid("Plasma client called abort on a sealed object"); + } // Make sure that the Plasma client only has one reference to the object. If // it has more, then the client needs to release the buffer before calling // abort. - if (object_entry->second->count > 1) { + if (entry.count > 1) { return Status::Invalid("Plasma client cannot have a reference to the buffer."); } // Send the abort request. RETURN_NOT_OK(SendAbortRequest(store_conn_, object_id)); // Decrease the reference count to zero, then remove the object. - object_entry->second->count--; + entry.count--; RETURN_NOT_OK(MarkObjectUnused(object_id)); std::vector buffer; @@ -860,7 +874,12 @@ Status PlasmaClient::Impl::DecodeNotification(const uint8_t* buffer, ObjectID* o int64_t* data_size, int64_t* metadata_size) { auto object_info = flatbuffers::GetRoot(buffer); - ARROW_CHECK(object_info->object_id()->size() == sizeof(ObjectID)); + if (object_info->object_id()->size() != sizeof(ObjectID)) { + return Status::Invalid( + "The size of ObjectID in the message is different from the size " + "of ObjectID in Plasma. 
The message could have been corrupt."); + } + memcpy(object_id, object_info->object_id()->data(), sizeof(ObjectID)); if (object_info->is_deletion()) { *data_size = -1; diff --git a/cpp/src/plasma/io/connection.cc b/cpp/src/plasma/io/connection.cc index 4cd4aa4b401..4823385db78 100644 --- a/cpp/src/plasma/io/connection.cc +++ b/cpp/src/plasma/io/connection.cc @@ -138,7 +138,9 @@ void ServerConnection::WriteMessageAsync(int64_t type, int64_t length, Status ServerConnection::RecvFd(int* fd) { *fd = recv_fd(GetNativeHandle()); - ARROW_CHECK(*fd); + if (*fd < 0) { + return Status::Invalid("Got an invalid fd."); + } return Status::OK(); } diff --git a/cpp/src/plasma/store.cc b/cpp/src/plasma/store.cc index 44fba5413f7..39168fb65b0 100644 --- a/cpp/src/plasma/store.cc +++ b/cpp/src/plasma/store.cc @@ -713,8 +713,9 @@ void PlasmaStore::HandleAccept(const asio::error_code& error) { const uint8_t* message) { Status s = ProcessClientMessage(client, message_type, length, message); if (!s.ok()) { - ARROW_LOG(FATAL) << "[Plasma Store] Failed to process the event" - << "(type=" << message_type << "): " << s; + ARROW_LOG(ERROR) << "[PlasmaStore] Failed to process the event" + << "(type=" << message_type << "): " << s << ", " + << "fd = " << client->GetNativeHandle(); } }; // Accept a new local client and dispatch it to the store. @@ -728,18 +729,8 @@ void PlasmaStore::HandleAccept(const asio::error_code& error) { DoAccept(); } -void PlasmaStore::ProcessDisconnectClient( +void PlasmaStore::ReleaseClientResources( const std::shared_ptr& client) { - ARROW_CHECK(client->IsOpen()); - auto it = connected_clients_.find(client); - ARROW_CHECK(it != connected_clients_.end()); - // Remove the client from the notification list. - if (notification_clients_.count(client) > 0) { - notification_clients_.erase(client); - } - // Close the client. - ARROW_LOG(INFO) << "Disconnecting client on fd " << client->GetNativeHandle(); - client->Close(); // Release all the objects that the client was using. 
std::unordered_map sealed_objects; for (const auto& object_id : client->object_ids) { @@ -766,7 +757,33 @@ void PlasmaStore::ProcessDisconnectClient( client->RemoveObjectID(entry.first); DecreaseObjectRefCount(entry.first, entry.second); } +} + +void PlasmaStore::ProcessDisconnectClient( + const std::shared_ptr& client) { + if (!client->IsOpen()) { + ARROW_LOG(ERROR) << "Received disconnection request from a disconnected client."; + return; + } + // Close the client. + ARROW_LOG(INFO) << "Disconnecting client on fd " << client->GetNativeHandle(); + client->Close(); + + // Remove the client from the connection set. + auto it = connected_clients_.find(client); + if (it == connected_clients_.end()) { + ARROW_LOG(FATAL) << "[PlasmaStore] (on DisconnectClient) Unexpected error: The " + << "client to disconnect is not in the connected clients list."; + return; + } connected_clients_.erase(it); + // Remove the client from the notification set. + if (notification_clients_.count(client) > 0) { + notification_clients_.erase(client); + } + + // Release resources. + ReleaseClientResources(client); } Status PlasmaStore::ProcessClientMessage(const std::shared_ptr& client, @@ -797,7 +814,8 @@ Status PlasmaStore::ProcessClientMessage(const std::shared_ptr auto status = client->SendFd(object.store_fd); if (!status.ok()) { // TODO(suquark): Should we close the client here? 
- ARROW_LOG(ERROR) << "Failed to send a mmap fd to client"; + ARROW_LOG(ERROR) << "[PlasmaStore] (on CreateRequest) Failed to send a mmap fd" + << " to the client."; } } } break; @@ -834,9 +852,10 @@ Status PlasmaStore::ProcessClientMessage(const std::shared_ptr } break; case MessageType::PlasmaAbortRequest: { RETURN_NOT_OK(ReadAbortRequest(message_data, message_size, &object_id)); - ARROW_CHECK(AbortObject(object_id, client) == 1) << "To abort an object, the only " - "client currently using it " - "must be the creator."; + if (AbortObject(object_id, client) != 1) { + ARROW_LOG(ERROR) << "[PlasmaStore] (on AbortRequest) To abort an object, the " + << "only client currently using it must be the creator."; + } RETURN_NOT_OK(SendAbortReply(client, object_id)); } break; case MessageType::PlasmaGetRequest: { diff --git a/cpp/src/plasma/store.h b/cpp/src/plasma/store.h index 2c2d84a2a1c..6ce6605298d 100644 --- a/cpp/src/plasma/store.h +++ b/cpp/src/plasma/store.h @@ -167,6 +167,11 @@ class PlasmaStore { /// \param client The client whose GetRequests should be removed. void RemoveGetRequestsForClient(const std::shared_ptr& client); + /// Release all resources used by the client. + /// + /// \param client The client whose resources should be released. + void ReleaseClientResources(const std::shared_ptr& client); + void ReturnFromGet(GetRequest* get_req); void UpdateObjectGetRequests(const ObjectID& object_id); From 0350d413c83aeb62e49bba59aa3764b24a38b7dc Mon Sep 17 00:00:00 2001 From: suquark Date: Wed, 27 Mar 2019 21:13:53 +0800 Subject: [PATCH 06/52] Fix --- cpp/src/plasma/client.cc | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index 58aaee960b6..78faa8b243e 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -484,10 +484,9 @@ Status PlasmaClient::Impl::GetBuffers( // This client created the object but hasn't sealed it. 
If we call Get // with no timeout, we will deadlock, because this client won't be able to // call Seal. - if (timeout_ms != -1) { + if (timeout_ms == -1) { return Status::Invalid( - "Plasma client called get on an" - " unsealed object that it created"); + "Plasma client called get on an unsealed object that it created"); } ARROW_LOG(WARNING) << "Attempting to get an object that this client created but hasn't sealed."; @@ -782,8 +781,7 @@ Status PlasmaClient::Impl::Abort(const ObjectID& object_id) { auto object_entry = objects_in_use_.find(object_id); if (object_entry == objects_in_use_.end()) { return Status::Invalid( - "Plasma client called abort on " - "an object without a reference to it"); + "Plasma client called abort on an object without a reference to it"); } auto& entry = *object_entry->second; From 24338727979d78c6e18ea350d517e43fc9787fcc Mon Sep 17 00:00:00 2001 From: Marius Seritan Date: Fri, 28 Jun 2019 20:09:00 -0700 Subject: [PATCH 07/52] ARROW-5785: [Rust] Make the datafusion cli dependencies optional Dependent crates may not want the rustyline dependency, specially since the nightly support seems to be custom. Introduce a "cli" feature to allow consumers to not bring in the cli depedencies. Author: Marius Seritan Closes #4742 from winding-lines/master and squashes the following commits: 233158708 Make the datafusion cli optional Dependent crates may not want the rustyline dependency, specially since the nightly support seems to be custom. Introduce a "cli" feature to allow consumers to not bring in the cli depedencies. 
--- rust/datafusion/Cargo.toml | 9 +++++++-- rust/datafusion/src/bin/main.rs | 25 +++++++++++++++++++++++++ rust/datafusion/src/bin/repl.rs | 7 ++----- 3 files changed, 34 insertions(+), 7 deletions(-) create mode 100644 rust/datafusion/src/bin/main.rs diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml index 1364f2440fb..4163f0bff99 100644 --- a/rust/datafusion/Cargo.toml +++ b/rust/datafusion/Cargo.toml @@ -36,7 +36,11 @@ path = "src/lib.rs" [[bin]] name = "datafusion-cli" -path = "src/bin/repl.rs" +path = "src/bin/main.rs" + +[features] +default = ["cli"] +cli = ["rustyline"] [dependencies] fnv = "1.0.3" @@ -47,12 +51,13 @@ serde_derive = "1.0.80" serde_json = "1.0.33" sqlparser = "0.2.0" clap = "2.33.0" -rustyline = "4.1.0" prettytable-rs = "0.8.0" +rustyline = {version = "4.1.0", optional = true} [dev-dependencies] criterion = "0.2.0" + [[bench]] name = "aggregate_query_sql" harness = false diff --git a/rust/datafusion/src/bin/main.rs b/rust/datafusion/src/bin/main.rs new file mode 100644 index 00000000000..deb5b796b2d --- /dev/null +++ b/rust/datafusion/src/bin/main.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// Only bring in dependencies for the repl when the cli feature is enabled. +#[cfg(feature = "cli")] +mod repl; + +pub fn main() { + #[cfg(feature = "cli")] + repl::main() +} diff --git a/rust/datafusion/src/bin/repl.rs b/rust/datafusion/src/bin/repl.rs index 88a88943940..7ef042ed431 100644 --- a/rust/datafusion/src/bin/repl.rs +++ b/rust/datafusion/src/bin/repl.rs @@ -17,12 +17,9 @@ #![allow(bare_trait_objects)] -#[macro_use] -extern crate clap; - use arrow::array::*; use arrow::datatypes::{DataType, TimeUnit}; -use clap::{App, Arg}; +use clap::{crate_version, App, Arg}; use datafusion::error::{ExecutionError, Result}; use datafusion::execution::context::ExecutionContext; use datafusion::execution::relation::Relation; @@ -32,7 +29,7 @@ use std::cell::RefMut; use std::env; use std::path::Path; -fn main() { +pub fn main() { let matches = App::new("DataFusion") .version(crate_version!()) .about( From 4d25902f3795597e89b38347ebfe41ab510b9d2e Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 29 Jun 2019 15:09:44 +0900 Subject: [PATCH 08/52] ARROW-5787: [Release][Rust] Use local modules to verify RC Because we don't publish modules to crates.io yet. 
Author: Sutou Kouhei Closes #4747 from kou/release-verify-rust and squashes the following commits: c424af6b5 Use local modules to verify RC --- dev/release/verify-release-candidate.sh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 159d7dde7eb..4310536e574 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -458,6 +458,14 @@ test_rust() { # we are targeting Rust nightly for releases rustup default nightly + # use local modules because we don't publish modules to crates.io yet + sed \ + -i.bak \ + -E \ + -e 's/^arrow = "([^"]*)"/arrow = { version = "\1", path = "..\/arrow" }/g' \ + -e 's/^parquet = "([^"]*)"/parquet = { version = "\1", path = "..\/parquet" }/g' \ + */Cargo.toml + # raises on any warnings RUSTFLAGS="-D warnings" cargo build cargo test From dff73a4a8d2ec6ab4601e274226f328008693631 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 30 Jun 2019 14:40:38 +0900 Subject: [PATCH 09/52] ARROW-5793: [Release] Avoid duplicated known host SSH error in dev/release/03-binary.sh Author: Sutou Kouhei Closes #4753 from kou/release-upload-binary-avoid-known-host-duplication and squashes the following commits: 7fc0018bc Avoid duplicated known host SSH error in dev/release/03-binary.sh --- dev/release/03-binary.sh | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/dev/release/03-binary.sh b/dev/release/03-binary.sh index fa1119a3b99..74602cda749 100755 --- a/dev/release/03-binary.sh +++ b/dev/release/03-binary.sh @@ -88,13 +88,22 @@ docker_run() { docker_gpg_ssh() { local ssh_port=$1 shift - ssh \ - -o StrictHostKeyChecking=no \ - -i "${docker_ssh_key}" \ - -p ${ssh_port} \ - -R "/home/arrow/.gnupg/S.gpg-agent:${gpg_agent_extra_socket}" \ - arrow@127.0.0.1 \ - "$@" + local known_hosts_file=$(mktemp -t "arrow-binary-gpg-ssh-known-hosts.XXXXX") + local exit_code= + if ssh \ + -o 
StrictHostKeyChecking=no \ + -o UserKnownHostsFile=${known_hosts_file} \ + -i "${docker_ssh_key}" \ + -p ${ssh_port} \ + -R "/home/arrow/.gnupg/S.gpg-agent:${gpg_agent_extra_socket}" \ + arrow@127.0.0.1 \ + "$@"; then + exit_code=$?; + else + exit_code=$?; + fi + rm -f ${known_hosts_file} + return ${exit_code} } docker_run_gpg_ready() { From c2c58765575754ea3bce88f1f14e17302177ddd0 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 30 Jun 2019 16:19:00 +0900 Subject: [PATCH 10/52] ARROW-5794: [Release] Skip uploading already uploaded binaries Author: Sutou Kouhei Closes #4754 from kou/release-skip-uploaded-binary and squashes the following commits: e8cd528b4 Skip already uploaded file --- dev/release/03-binary.sh | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dev/release/03-binary.sh b/dev/release/03-binary.sh index 74602cda749..f3f418ca641 100755 --- a/dev/release/03-binary.sh +++ b/dev/release/03-binary.sh @@ -73,6 +73,8 @@ fi : ${BINTRAY_REPOSITORY:=apache/arrow} : ${SOURCE_BINTRAY_REPOSITORY:=${BINTRAY_REPOSITORY}} +BINTRAY_DOWNLOAD_URL_BASE=https://dl.bintray.com + docker_run() { docker \ run \ @@ -194,7 +196,7 @@ download_files() { --fail \ --location \ --output ${file} \ - https://dl.bintray.com/${SOURCE_BINTRAY_REPOSITORY}/${file} & + ${BINTRAY_DOWNLOAD_URL_BASE}/${SOURCE_BINTRAY_REPOSITORY}/${file} & done } @@ -244,6 +246,16 @@ sign_and_upload_file() { local local_path=$4 local upload_path=$5 + local sha256=$(shasum -a 256 ${local_path} | awk '{print $1}') + local download_path=/${BINTRAY_REPOSITORY}/${target}-rc/${upload_path} + if curl \ + --fail \ + --head \ + ${BINTRAY_DOWNLOAD_URL_BASE}${download_path} | \ + grep -q "^X-Checksum-Sha2: ${sha256}"; then + return 0 + fi + upload_file ${version} ${rc} ${target} ${local_path} ${upload_path} local suffix= From f0392dcf5d77ae499a66581edb3c0db3a07fd5a3 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 30 Jun 2019 16:23:30 +0900 Subject: [PATCH 11/52] ARROW-5795: 
[Release] Add missing waits on uploading binaries Author: Sutou Kouhei Closes #4755 from kou/release-add-missing-wait and squashes the following commits: 081c1aa10 Add missing waits on uploading binaries --- dev/release/03-binary.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dev/release/03-binary.sh b/dev/release/03-binary.sh index f3f418ca641..92cfaaf08b6 100755 --- a/dev/release/03-binary.sh +++ b/dev/release/03-binary.sh @@ -317,6 +317,7 @@ upload_deb() { for base_path in *; do upload_deb_file ${version} ${rc} ${distribution} ${code_name} ${base_path} & done + wait } upload_apt() { @@ -447,6 +448,7 @@ upload_rpm() { ${distribution_version} \ ${rpm_path} & done + wait } upload_yum() { @@ -516,6 +518,7 @@ upload_python() { ${base_path} \ ${version}-rc${rc}/${base_path} & done + wait } docker build -t ${docker_image_name} ${SOURCE_DIR}/binary From 53e813cb200983ce908252816b86e7df2797770e Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 30 Jun 2019 16:40:36 +0900 Subject: [PATCH 12/52] ARROW-5796: [Release][APT] Update expected package list libplasma-glib-doc and libgandiva-glib-doc are unavailable unexpetedly. It should be fixed in the next release. 
Author: Sutou Kouhei Closes #4756 from kou/release-verify-apt-update-packages and squashes the following commits: d95f70d33 Update expected package list --- dev/release/verify-apt.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/release/verify-apt.sh b/dev/release/verify-apt.sh index 6cca89c0000..8acb56042bc 100755 --- a/dev/release/verify-apt.sh +++ b/dev/release/verify-apt.sh @@ -125,12 +125,12 @@ if [ "${have_python}" = "yes" ]; then fi apt install -y -V libplasma-glib-dev=${deb_version} -apt install -y -V libplasma-glib-doc=${deb_version} -# apt install -y -V plasma-store-server=${deb_version} +# apt install -y -V libplasma-glib-doc=${deb_version} +apt install -y -V plasma-store-server=${deb_version} if [ "${have_gandiva}" = "yes" ]; then apt install -y -V libgandiva-glib-dev=${deb_version} - apt install -y -V libgandiva-glib-doc=${deb_version} + # apt install -y -V libgandiva-glib-doc=${deb_version} fi apt install -y -V libparquet-glib-dev=${deb_version} From c29462c9b2c1fe9e024d3999a99d80775197b6d9 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 30 Jun 2019 16:44:02 +0900 Subject: [PATCH 13/52] ARROW-5797: [Release][APT] Update supported distributions Author: Sutou Kouhei Closes #4758 from kou/release-verify-apt-update-distributions and squashes the following commits: 6a2af8dee Update supported distributions --- dev/release/verify-release-candidate.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 4310536e574..f694fb4efc0 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -167,7 +167,6 @@ test_binary() { test_apt() { for target in debian-stretch \ debian-buster \ - ubuntu-trusty \ ubuntu-xenial \ ubuntu-bionic \ ubuntu-cosmic \ From 10083ab9bc4b274dcf8df6cf469e9a835d513867 Mon Sep 17 00:00:00 2001 From: "Uwe L. 
Korn" Date: Sun, 30 Jun 2019 14:04:46 -0500 Subject: [PATCH 14/52] ARROW-5609: [C++] Set CMP0068 CMake policy to avoid macOS warnings C++ and Python Tests pass locally so this seems to be ok for us. Author: Uwe L. Korn Closes #4752 from xhochy/ARROW-5609 and squashes the following commits: 6e087d67b ARROW-5609: Set CMP0068 CMake policy to avoid macOS warnings --- cpp/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c09bdebc68a..0e19b81ee0f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -70,6 +70,11 @@ if(POLICY CMP0054) cmake_policy(SET CMP0054 NEW) endif() +if(POLICY CMP0068) + # https://cmake.org/cmake/help/v3.9/policy/CMP0068.html + cmake_policy(SET CMP0068 NEW) +endif() + # don't ignore _ROOT variables in find_package if(POLICY CMP0074) # https://cmake.org/cmake/help/v3.12/policy/CMP0074.html From 91b4cbce0c3eb4d82d8f5a33b878d1bfba7651a3 Mon Sep 17 00:00:00 2001 From: tianchen Date: Sun, 30 Jun 2019 20:09:08 -0700 Subject: [PATCH 15/52] ARROW-5726: [Java] Implement a common interface for int vectors Related to [ARROW-5726](https://issues.apache.org/jira/browse/ARROW-5726). Now in DictionaryEncoder#encode it use reflection to pull out the set method and then set values. Set values by reflection is not efficient and code structure is not elegant such as Method setter = null; for (Class c : Arrays.asList(int.class, long.class)) { try { setter = indices.getClass().getMethod("setSafe", int.class, c); break; } catch (NoSuchMethodException e) { // ignore } } Implement a common interface for int vectors to directly get set method and set values seems a good choice. 
Author: tianchen Closes #4698 from tianchen92/ARROW-5726 and squashes the following commits: 37ec9cc48 resolve comments 021928254 fix overflow b97cbef81 fix a6f351f32 resolve comments 5e58c5244 Implement a common interface for int vectors --- .../apache/arrow/vector/BaseIntVector.java | 29 +++++++++++ .../org/apache/arrow/vector/BigIntVector.java | 7 ++- .../org/apache/arrow/vector/IntVector.java | 7 ++- .../apache/arrow/vector/SmallIntVector.java | 9 +++- .../apache/arrow/vector/TinyIntVector.java | 9 +++- .../org/apache/arrow/vector/UInt1Vector.java | 13 ++++- .../org/apache/arrow/vector/UInt2Vector.java | 9 +++- .../org/apache/arrow/vector/UInt4Vector.java | 7 ++- .../org/apache/arrow/vector/UInt8Vector.java | 7 ++- .../vector/dictionary/DictionaryEncoder.java | 50 ++++++------------- 10 files changed, 104 insertions(+), 43 deletions(-) create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/BaseIntVector.java diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseIntVector.java new file mode 100644 index 00000000000..74387de9486 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseIntVector.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +/** + * Interface for all int type vectors. + */ +public interface BaseIntVector extends ValueVector { + + /** + * set the encoded value from a {@link org.apache.arrow.vector.dictionary.Dictionary}. + */ + void setEncodedValue(int index, int value); +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java index 65ce53e2581..416ffd53fd3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java @@ -35,7 +35,7 @@ * integer values which could be null. A validity buffer (bit vector) is * maintained to track which elements in the vector are null. */ -public class BigIntVector extends BaseFixedWidthVector { +public class BigIntVector extends BaseFixedWidthVector implements BaseIntVector { public static final byte TYPE_WIDTH = 8; private final FieldReader reader; @@ -339,6 +339,11 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((BigIntVector) to); } + @Override + public void setEncodedValue(int index, int value) { + this.setSafe(index, value); + } + private class TransferImpl implements TransferPair { BigIntVector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java index 3a8207f0abc..5255d87a2ed 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java @@ -35,7 +35,7 @@ * integer values which could be null. A validity buffer (bit vector) is * maintained to track which elements in the vector are null. 
*/ -public class IntVector extends BaseFixedWidthVector { +public class IntVector extends BaseFixedWidthVector implements BaseIntVector { public static final byte TYPE_WIDTH = 4; private final FieldReader reader; @@ -343,6 +343,11 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((IntVector) to); } + @Override + public void setEncodedValue(int index, int value) { + this.setSafe(index, value); + } + private class TransferImpl implements TransferPair { IntVector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java index dddc46fef2b..2d3f78f9766 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java @@ -20,6 +20,7 @@ import static org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.complex.impl.SmallIntReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.NullableSmallIntHolder; @@ -35,7 +36,7 @@ * short values which could be null. A validity buffer (bit vector) is * maintained to track which elements in the vector are null. 
*/ -public class SmallIntVector extends BaseFixedWidthVector { +public class SmallIntVector extends BaseFixedWidthVector implements BaseIntVector { public static final byte TYPE_WIDTH = 2; private final FieldReader reader; @@ -370,6 +371,12 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((SmallIntVector) to); } + @Override + public void setEncodedValue(int index, int value) { + Preconditions.checkArgument(value <= Short.MAX_VALUE, "value is overflow: %s", value); + this.setSafe(index, value); + } + private class TransferImpl implements TransferPair { SmallIntVector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java index df40b6e57cc..66f7ca35d01 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java @@ -20,6 +20,7 @@ import static org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.complex.impl.TinyIntReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.NullableTinyIntHolder; @@ -35,7 +36,7 @@ * byte values which could be null. A validity buffer (bit vector) is * maintained to track which elements in the vector are null. 
*/ -public class TinyIntVector extends BaseFixedWidthVector { +public class TinyIntVector extends BaseFixedWidthVector implements BaseIntVector { public static final byte TYPE_WIDTH = 1; private final FieldReader reader; @@ -370,6 +371,12 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((TinyIntVector) to); } + @Override + public void setEncodedValue(int index, int value) { + Preconditions.checkArgument(value <= Byte.MAX_VALUE, "value is overflow: %s", value); + this.setSafe(index, value); + } + private class TransferImpl implements TransferPair { TinyIntVector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java index c5133344fe8..85d48ad9e37 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java @@ -20,6 +20,7 @@ import static org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.complex.impl.UInt1ReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.NullableUInt1Holder; @@ -35,7 +36,7 @@ * integer values which could be null. A validity buffer (bit vector) is * maintained to track which elements in the vector are null. */ -public class UInt1Vector extends BaseFixedWidthVector { +public class UInt1Vector extends BaseFixedWidthVector implements BaseIntVector { private static final byte TYPE_WIDTH = 1; private final FieldReader reader; @@ -150,7 +151,7 @@ public void copyFrom(int fromIndex, int thisIndex, UInt1Vector from) { } /** - * Identical to {@link #copyFrom()} but reallocates buffer if index is larger + * Identical to {@link #copyFrom(int, int, UInt1Vector)} but reallocates buffer if index is larger * than capacity. 
*/ public void copyFromSafe(int fromIndex, int thisIndex, UInt1Vector from) { @@ -329,6 +330,14 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((UInt1Vector) to); } + @Override + public void setEncodedValue(int index, int value) { + Preconditions.checkArgument(value <= 0xFF, "value is overflow: %s", value); + this.setSafe(index, value); + } + + + private class TransferImpl implements TransferPair { UInt1Vector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java index 631050d57a2..dbea9f82b6e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java @@ -20,6 +20,7 @@ import static org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.complex.impl.UInt2ReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.NullableUInt2Holder; @@ -35,7 +36,7 @@ * integer values which could be null. A validity buffer (bit vector) is * maintained to track which elements in the vector are null. 
*/ -public class UInt2Vector extends BaseFixedWidthVector { +public class UInt2Vector extends BaseFixedWidthVector implements BaseIntVector { private static final byte TYPE_WIDTH = 2; private final FieldReader reader; @@ -308,6 +309,12 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((UInt2Vector) to); } + @Override + public void setEncodedValue(int index, int value) { + Preconditions.checkArgument(value <= Character.MAX_VALUE, "value is overflow: %s", value); + this.setSafe(index, value); + } + private class TransferImpl implements TransferPair { UInt2Vector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java index 84e6b8f3788..b2eadc2a22b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java @@ -35,7 +35,7 @@ * integer values which could be null. A validity buffer (bit vector) is * maintained to track which elements in the vector are null. */ -public class UInt4Vector extends BaseFixedWidthVector { +public class UInt4Vector extends BaseFixedWidthVector implements BaseIntVector { private static final byte TYPE_WIDTH = 4; private final FieldReader reader; @@ -301,6 +301,11 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((UInt4Vector) to); } + @Override + public void setEncodedValue(int index, int value) { + this.setSafe(index, value); + } + private class TransferImpl implements TransferPair { UInt4Vector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java index 0f8da381ee5..a1b3bdabdee 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java @@ -37,7 +37,7 @@ * integer values which could be null. 
A validity buffer (bit vector) is * maintained to track which elements in the vector are null. */ -public class UInt8Vector extends BaseFixedWidthVector { +public class UInt8Vector extends BaseFixedWidthVector implements BaseIntVector { private static final byte TYPE_WIDTH = 8; private final FieldReader reader; @@ -302,6 +302,11 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((UInt8Vector) to); } + @Override + public void setEncodedValue(int index, int value) { + this.setSafe(index, value); + } + private class TransferImpl implements TransferPair { UInt8Vector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java index 1c2a0aced17..698191c2ca2 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -17,12 +17,10 @@ package org.apache.arrow.vector.dictionary; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.arrow.vector.BaseIntVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.types.Types.MinorType; @@ -61,43 +59,27 @@ public static ValueVector encode(ValueVector vector, Dictionary dictionary) { Field indexField = new Field(valueField.getName(), indexFieldType, null); // vector to hold our indices (dictionary encoded values) - FieldVector indices = indexField.createVector(vector.getAllocator()); - - // use reflection to pull out the set method - // TODO implement a common interface for int vectors - Method setter = null; - for (Class c : Arrays.asList(int.class, long.class)) { - try { - setter = indices.getClass().getMethod("setSafe", int.class, c); - break; - } catch 
(NoSuchMethodException e) { - // ignore - } + FieldVector createdVector = indexField.createVector(vector.getAllocator()); + if (! (createdVector instanceof BaseIntVector)) { + throw new IllegalArgumentException("Dictionary encoding does not have a valid int type:" + + createdVector.getClass()); } - if (setter == null) { - throw new IllegalArgumentException("Dictionary encoding does not have a valid int type:" + indices.getClass()); - } - - int count = vector.getValueCount(); + BaseIntVector indices = (BaseIntVector) createdVector; indices.allocateNew(); - try { - for (int i = 0; i < count; i++) { - Object value = vector.getObject(i); - if (value != null) { // if it's null leave it null - // note: this may fail if value was not included in the dictionary - Object encoded = lookUps.get(value); - if (encoded == null) { - throw new IllegalArgumentException("Dictionary encoding not defined for value:" + value); - } - setter.invoke(indices, i, encoded); + int count = vector.getValueCount(); + + for (int i = 0; i < count; i++) { + Object value = vector.getObject(i); + if (value != null) { // if it's null leave it null + // note: this may fail if value was not included in the dictionary + Integer encoded = lookUps.get(value); + if (encoded == null) { + throw new IllegalArgumentException("Dictionary encoding not defined for value:" + value); } + indices.setEncodedValue(i, encoded); } - } catch (IllegalAccessException e) { - throw new RuntimeException("IllegalAccessException invoking vector mutator set():", e); - } catch (InvocationTargetException e) { - throw new RuntimeException("InvocationTargetException invoking vector mutator set():", e.getCause()); } indices.setValueCount(count); From 175ad654d45f4c014749b661ae3971ffe8f0512b Mon Sep 17 00:00:00 2001 From: Pindikura Ravindra Date: Mon, 1 Jul 2019 15:11:26 +0530 Subject: [PATCH 16/52] ARROW-3459: [C++][Gandiva] support for string o/p - If the output vectors aren't provided, allow resizable data buffers. 
- If the output vectors are provided, assert that the data buffer is resizeable. - use a cpp function to write to string-like o/p buffers, this checks for capacity and updates the offset vector. Author: Pindikura Ravindra Closes #4760 from pravindra/varlen and squashes the following commits: 0068b6a07 ARROW-3459: support for string o/p --- cpp/src/gandiva/annotator.cc | 26 ++++-- cpp/src/gandiva/annotator.h | 5 +- cpp/src/gandiva/annotator_test.cc | 5 +- cpp/src/gandiva/expr_validator.cc | 6 ++ cpp/src/gandiva/field_descriptor.h | 11 ++- cpp/src/gandiva/gdv_function_stubs.cc | 37 +++++++++ cpp/src/gandiva/jni/jni_common.cc | 49 +++++++++-- cpp/src/gandiva/llvm_generator.cc | 20 +++++ cpp/src/gandiva/llvm_generator.h | 3 + cpp/src/gandiva/projector.cc | 81 ++++++++++++++----- cpp/src/gandiva/projector.h | 4 +- cpp/src/gandiva/tests/utf8_test.cc | 28 +++++-- .../arrow/gandiva/evaluator/Projector.java | 15 +++- .../gandiva/evaluator/ProjectorTest.java | 74 +++++++++++++++++ 14 files changed, 314 insertions(+), 50 deletions(-) diff --git a/cpp/src/gandiva/annotator.cc b/cpp/src/gandiva/annotator.cc index 754d70e0e04..0eab915d351 100644 --- a/cpp/src/gandiva/annotator.cc +++ b/cpp/src/gandiva/annotator.cc @@ -31,30 +31,35 @@ FieldDescriptorPtr Annotator::CheckAndAddInputFieldDescriptor(FieldPtr field) { return found->second; } - auto desc = MakeDesc(field); + auto desc = MakeDesc(field, false /*is_output*/); in_name_to_desc_[field->name()] = desc; return desc; } FieldDescriptorPtr Annotator::AddOutputFieldDescriptor(FieldPtr field) { - auto desc = MakeDesc(field); + auto desc = MakeDesc(field, true /*is_output*/); out_descs_.push_back(desc); return desc; } -FieldDescriptorPtr Annotator::MakeDesc(FieldPtr field) { +FieldDescriptorPtr Annotator::MakeDesc(FieldPtr field, bool is_output) { int data_idx = buffer_count_++; int validity_idx = buffer_count_++; int offsets_idx = FieldDescriptor::kInvalidIdx; if (arrow::is_binary_like(field->type()->id())) { offsets_idx = 
buffer_count_++; } - return std::make_shared(field, data_idx, validity_idx, offsets_idx); + int data_buffer_ptr_idx = FieldDescriptor::kInvalidIdx; + if (is_output) { + data_buffer_ptr_idx = buffer_count_++; + } + return std::make_shared(field, data_idx, validity_idx, offsets_idx, + data_buffer_ptr_idx); } void Annotator::PrepareBuffersForField(const FieldDescriptor& desc, const arrow::ArrayData& array_data, - EvalBatch* eval_batch) { + EvalBatch* eval_batch, bool is_output) { int buffer_idx = 0; // The validity buffer is optional. Use nullptr if it does not have one. @@ -74,7 +79,12 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc, uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); eval_batch->SetBuffer(desc.data_idx(), data_buf); - ++buffer_idx; + if (is_output) { + // pass in the Buffer object for output data buffers. Can be used for resizing. + uint8_t* data_buf_ptr = + reinterpret_cast(array_data.buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr); + } } EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch, @@ -92,14 +102,14 @@ EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch, } PrepareBuffersForField(*(found->second), *(record_batch.column(i))->data(), - eval_batch.get()); + eval_batch.get(), false /*is_output*/); } // Fill in the entries for the output fields. int idx = 0; for (auto& arraydata : out_vector) { const FieldDescriptorPtr& desc = out_descs_.at(idx); - PrepareBuffersForField(*desc, *arraydata, eval_batch.get()); + PrepareBuffersForField(*desc, *arraydata, eval_batch.get(), true /*is_output*/); ++idx; } return eval_batch; diff --git a/cpp/src/gandiva/annotator.h b/cpp/src/gandiva/annotator.h index c0ddc024635..dcf665c04a5 100644 --- a/cpp/src/gandiva/annotator.h +++ b/cpp/src/gandiva/annotator.h @@ -54,12 +54,13 @@ class GANDIVA_EXPORT Annotator { private: /// Annotate a field and return the descriptor. 
- FieldDescriptorPtr MakeDesc(FieldPtr field); + FieldDescriptorPtr MakeDesc(FieldPtr field, bool is_output); /// Populate eval_batch by extracting the raw buffers from the arrow array, whose /// contents are represent by the annotated descriptor 'desc'. void PrepareBuffersForField(const FieldDescriptor& desc, - const arrow::ArrayData& array_data, EvalBatch* eval_batch); + const arrow::ArrayData& array_data, EvalBatch* eval_batch, + bool is_output); /// The list of input/output buffers (includes bitmap buffers, value buffers and /// offset buffers). diff --git a/cpp/src/gandiva/annotator_test.cc b/cpp/src/gandiva/annotator_test.cc index dabf4e65990..cd829f75c51 100644 --- a/cpp/src/gandiva/annotator_test.cc +++ b/cpp/src/gandiva/annotator_test.cc @@ -73,6 +73,7 @@ TEST_F(TestAnnotator, TestAdd) { EXPECT_EQ(desc_sum->field(), field_sum); EXPECT_EQ(desc_sum->data_idx(), 4); EXPECT_EQ(desc_sum->validity_idx(), 5); + EXPECT_EQ(desc_sum->data_buffer_ptr_idx(), 6); // prepare record batch int num_records = 100; @@ -85,7 +86,7 @@ TEST_F(TestAnnotator, TestAdd) { auto arrow_sum = MakeInt32Array(num_records); EvalBatchPtr batch = annotator.PrepareEvalBatch(*record_batch, {arrow_sum->data()}); - EXPECT_EQ(batch->GetNumBuffers(), 6); + EXPECT_EQ(batch->GetNumBuffers(), 7); auto buffers = batch->GetBufferArray(); EXPECT_EQ(buffers[desc_a->validity_idx()], arrow_v0->data()->buffers.at(0)->data()); @@ -94,6 +95,8 @@ TEST_F(TestAnnotator, TestAdd) { EXPECT_EQ(buffers[desc_b->data_idx()], arrow_v1->data()->buffers.at(1)->data()); EXPECT_EQ(buffers[desc_sum->validity_idx()], arrow_sum->data()->buffers.at(0)->data()); EXPECT_EQ(buffers[desc_sum->data_idx()], arrow_sum->data()->buffers.at(1)->data()); + EXPECT_EQ(buffers[desc_sum->data_buffer_ptr_idx()], + reinterpret_cast(arrow_sum->data()->buffers.at(1).get())); auto bitmaps = batch->GetLocalBitMapArray(); EXPECT_EQ(bitmaps, nullptr); diff --git a/cpp/src/gandiva/expr_validator.cc b/cpp/src/gandiva/expr_validator.cc index 
923841c0bf9..bce43d53c8c 100644 --- a/cpp/src/gandiva/expr_validator.cc +++ b/cpp/src/gandiva/expr_validator.cc @@ -89,6 +89,12 @@ Status ExprValidator::Visit(const IfNode& node) { auto then_node_ret_type = node.then_node()->return_type(); auto else_node_ret_type = node.else_node()->return_type(); + // condition must be of boolean type. + ARROW_RETURN_IF( + !node.condition()->return_type()->Equals(arrow::boolean()), + Status::ExpressionValidationError("condition must be of boolean type, found type ", + node.condition()->return_type()->ToString())); + // Then-branch return type must match. ARROW_RETURN_IF(!if_node_ret_type->Equals(*then_node_ret_type), Status::ExpressionValidationError( diff --git a/cpp/src/gandiva/field_descriptor.h b/cpp/src/gandiva/field_descriptor.h index 70583b0405b..d931f378ff2 100644 --- a/cpp/src/gandiva/field_descriptor.h +++ b/cpp/src/gandiva/field_descriptor.h @@ -31,11 +31,12 @@ class FieldDescriptor { static const int kInvalidIdx = -1; FieldDescriptor(FieldPtr field, int data_idx, int validity_idx = kInvalidIdx, - int offsets_idx = kInvalidIdx) + int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx) : field_(field), data_idx_(data_idx), validity_idx_(validity_idx), - offsets_idx_(offsets_idx) {} + offsets_idx_(offsets_idx), + data_buffer_ptr_idx_(data_buffer_ptr_idx) {} /// Index of validity array in the array-of-buffers int validity_idx() const { return validity_idx_; } @@ -46,6 +47,9 @@ class FieldDescriptor { /// Index of offsets array in the array-of-buffers int offsets_idx() const { return offsets_idx_; } + /// Index of data buffer pointer in the array-of-buffers + int data_buffer_ptr_idx() const { return data_buffer_ptr_idx_; } + FieldPtr field() const { return field_; } const std::string& Name() const { return field_->name(); } @@ -53,11 +57,14 @@ class FieldDescriptor { bool HasOffsetsIdx() const { return offsets_idx_ != kInvalidIdx; } + bool HasDataBufferPtrIdx() const { return data_buffer_ptr_idx_ != 
kInvalidIdx; } + private: FieldPtr field_; int data_idx_; int validity_idx_; int offsets_idx_; + int data_buffer_ptr_idx_; }; } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 5eacdf769d0..570e0263f9a 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -72,6 +72,31 @@ bool gdv_fn_in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, reinterpret_cast*>(ptr); return holder->HasValue(std::string(data, data_len)); } + +int32_t gdv_fn_populate_varlen_vector(int64_t context_ptr, int8_t* data_ptr, + int32_t* offsets, int64_t slot, + const char* entry_buf, int32_t entry_len) { + auto buffer = reinterpret_cast(data_ptr); + int32_t offset = static_cast(buffer->size()); + + // This also sets the size in the buffer. + auto status = buffer->Resize(offset + entry_len, false /*shrink*/); + if (!status.ok()) { + gandiva::ExecutionContext* context = + reinterpret_cast(context_ptr); + + context->set_error_msg(status.message().c_str()); + return -1; + } + + // append the new entry. + memcpy(buffer->mutable_data() + offset, entry_buf, entry_len); + + // update offsets buffer. 
+ offsets[slot] = offset; + offsets[slot + 1] = offset + entry_len; + return 0; +} } namespace gandiva { @@ -135,6 +160,18 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_utf8", types->i1_type() /*return_type*/, args, reinterpret_cast(gdv_fn_in_expr_lookup_utf8)); + + // gdv_fn_populate_varlen_vector + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* offsets ptr + types->i64_type(), // int64_t slot + types->i8_ptr_type(), // const char* entry_buf + types->i32_type()}; // int32_t entry__len + + engine->AddGlobalMappingForFunc("gdv_fn_populate_varlen_vector", + types->i32_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_populate_varlen_vector)); } } // namespace gandiva diff --git a/cpp/src/gandiva/jni/jni_common.cc b/cpp/src/gandiva/jni/jni_common.cc index 09d27398b67..2ff4bc9619a 100644 --- a/cpp/src/gandiva/jni/jni_common.cc +++ b/cpp/src/gandiva/jni/jni_common.cc @@ -632,6 +632,32 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_build return module_id; } +/// +/// \brief Resizable buffer which resizes by doing a callback into java. +/// +class JavaResizableBuffer : public arrow::ResizableBuffer { + public: + JavaResizableBuffer(uint8_t* buffer, int32_t len) : ResizableBuffer(buffer, len) { + size_ = 0; + } + + Status Resize(const int64_t new_size, bool shrink_to_fit) override { + if (shrink_to_fit == true) { + return Status::NotImplemented("shrink not implemented"); + } else if (new_size < capacity()) { + size_ = new_size; + return Status::OK(); + } else { + // TODO: callback into java to re-alloc the buffer. 
+ return Status::NotImplemented("buffer expand not implemented"); + } + } + + Status Reserve(const int64_t new_capacity) override { + return Status::NotImplemented("reserve not implemented"); + } +}; + #define CHECK_OUT_BUFFER_IDX_AND_BREAK(idx, len) \ if (idx >= len) { \ status = gandiva::Status::Invalid("insufficient number of out_buf_addrs"); \ @@ -710,20 +736,31 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( int buf_idx = 0; int sz_idx = 0; for (FieldPtr field : ret_types) { + std::vector> buffers; + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); uint8_t* validity_buf = reinterpret_cast(out_bufs[buf_idx++]); jlong bitmap_sz = out_sizes[sz_idx++]; - std::shared_ptr bitmap_buf = - std::make_shared(validity_buf, bitmap_sz); + buffers.push_back(std::make_shared(validity_buf, bitmap_sz)); + + if (arrow::is_binary_like(field->type()->id())) { + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t* offsets_buf = reinterpret_cast(out_bufs[buf_idx++]); + jlong offsets_sz = out_sizes[sz_idx++]; + buffers.push_back( + std::make_shared(offsets_buf, offsets_sz)); + } CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); uint8_t* value_buf = reinterpret_cast(out_bufs[buf_idx++]); jlong data_sz = out_sizes[sz_idx++]; - std::shared_ptr data_buf = - std::make_shared(value_buf, data_sz); + if (arrow::is_binary_like(field->type()->id())) { + buffers.push_back(std::make_shared(value_buf, data_sz)); + } else { + buffers.push_back(std::make_shared(value_buf, data_sz)); + } - auto array_data = - arrow::ArrayData::Make(field->type(), output_row_count, {bitmap_buf, data_buf}); + auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers); output.push_back(array_data); } status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output); diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index 1d5946dec80..f5407556ee9 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ 
b/cpp/src/gandiva/llvm_generator.cc @@ -155,6 +155,14 @@ llvm::Value* LLVMGenerator::GetValidityReference(llvm::Value* arg_addrs, int idx return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), name + "_varray"); } +/// Get reference to data array at specified index in the args list. +llvm::Value* LLVMGenerator::GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, + FieldPtr field) { + const std::string& name = field->name(); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + return ir_builder()->CreateIntToPtr(load, types()->i8_ptr_type(), name + "_buf_ptr"); +} + /// Get reference to data array at specified index in the args list. llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx, FieldPtr field) { @@ -293,6 +301,10 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, FieldDescriptorPtr out builder->SetInsertPoint(loop_entry); llvm::Value* output_ref = GetDataReference(arg_addrs, output->data_idx(), output->field()); + llvm::Value* output_buffer_ptr_ref = GetDataBufferPtrReference( + arg_addrs, output->data_buffer_ptr_idx(), output->field()); + llvm::Value* output_offset_ref = + GetOffsetsReference(arg_addrs, output->offsets_idx(), output->field()); // Loop body builder->SetInsertPoint(loop_body); @@ -323,6 +335,7 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, FieldDescriptorPtr out // save the value in the output vector. builder->SetInsertPoint(loop_body_tail); + auto output_type_id = output->Type()->id(); if (output_type_id == arrow::Type::BOOL) { SetPackedBitValue(output_ref, loop_var, output_value->data()); @@ -330,6 +343,13 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, FieldDescriptorPtr out output_type_id == arrow::Type::DECIMAL) { llvm::Value* slot_offset = builder->CreateGEP(output_ref, loop_var); builder->CreateStore(output_value->data(), slot_offset); + } else if (arrow::is_binary_like(output_type_id)) { + // Var-len output. 
Make a function call to populate the data. + // if there is an error, the fn sets it in the context. And, will be returned at the + // end of this row batch. + AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var, + output_value->data(), output_value->length()}); } else { return Status::NotImplemented("output type ", output->Type()->ToString(), " not supported"); diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index a68f0d518e9..122eaf6243a 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -180,6 +180,9 @@ class GANDIVA_EXPORT LLVMGenerator { /// Generate code to load the vector at specified index and cast it as offsets array. llvm::Value* GetOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + /// Generate code to load the vector at specified index and cast it as buffer pointer. + llvm::Value* GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + /// Generate code for the value array of one expression. Status CodeGenExprValue(DexPtr value_expr, FieldDescriptorPtr output, int suffix_idx, llvm::Function** fn, diff --git a/cpp/src/gandiva/projector.cc b/cpp/src/gandiva/projector.cc index 6493fd4d908..d8c6a80b52f 100644 --- a/cpp/src/gandiva/projector.cc +++ b/cpp/src/gandiva/projector.cc @@ -168,28 +168,50 @@ Status Projector::Evaluate(const arrow::RecordBatch& batch, return Status::OK(); } -// TODO : handle variable-len vectors +// TODO : handle complex vectors (list/map/..) 
Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, arrow::MemoryPool* pool, ArrayDataPtr* array_data) { - const auto* fw_type = dynamic_cast(type.get()); - ARROW_RETURN_IF(fw_type == nullptr, - Status::Invalid("Unsupported output data type ", type)); - - std::shared_ptr null_bitmap; - int64_t bitmap_bytes = arrow::BitUtil::BytesForBits(num_records); - ARROW_RETURN_NOT_OK(arrow::AllocateBuffer(pool, bitmap_bytes, &null_bitmap)); + arrow::Status astatus; + std::vector> buffers; + + // The output vector always has a null bitmap. + std::shared_ptr bitmap_buffer; + int64_t size = arrow::BitUtil::BytesForBits(num_records); + ARROW_RETURN_NOT_OK(arrow::AllocateBuffer(pool, size, &bitmap_buffer)); + buffers.push_back(bitmap_buffer); + + // String/Binary vectors have an offsets array. + auto type_id = type->id(); + if (arrow::is_binary_like(type_id)) { + std::shared_ptr offsets_buffer; + auto offsets_len = arrow::BitUtil::BytesForBits((num_records + 1) * 32); + + ARROW_RETURN_NOT_OK(arrow::AllocateBuffer(pool, offsets_len, &offsets_buffer)); + buffers.push_back(offsets_buffer); + } - std::shared_ptr data; - int64_t data_len = arrow::BitUtil::BytesForBits(num_records * fw_type->bit_width()); - ARROW_RETURN_NOT_OK(arrow::AllocateBuffer(pool, data_len, &data)); + // The output vector always has a data array. + int64_t data_len; + std::shared_ptr data_buffer; + if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) { + const auto& fw_type = dynamic_cast(*type); + data_len = arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width()); + } else if (arrow::is_binary_like(type_id)) { + // we don't know the expected size for varlen output vectors. 
+ data_len = 0; + } else { + return Status::Invalid("Unsupported output data type " + type->ToString()); + } + ARROW_RETURN_NOT_OK(arrow::AllocateResizableBuffer(pool, data_len, &data_buffer)); // This is not strictly required but valgrind gets confused and detects this // as uninitialized memory access. See arrow::util::SetBitTo(). if (type->id() == arrow::Type::BOOL) { - memset(data->mutable_data(), 0, data_len); + memset(data_buffer->mutable_data(), 0, data_len); } + buffers.push_back(data_buffer); - *array_data = arrow::ArrayData::Make(type, num_records, {null_bitmap, data}); + *array_data = arrow::ArrayData::Make(type, num_records, buffers); return Status::OK(); } @@ -213,13 +235,32 @@ Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData& array_data, ARROW_RETURN_IF(bitmap_len < min_bitmap_len, Status::Invalid("Bitmap buffer too small for ", field.name())); - // verify size of data buffer. - // TODO : handle variable-len vectors - const auto& fw_type = dynamic_cast(*field.type()); - int64_t min_data_len = arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width()); - int64_t data_len = array_data.buffers[1]->capacity(); - ARROW_RETURN_IF(data_len < min_data_len, - Status::Invalid("Data buffer too small for ", field.name())); + auto type_id = field.type()->id(); + if (arrow::is_binary_like(type_id)) { + // validate size of offsets buffer. + int64_t min_offsets_len = arrow::BitUtil::BytesForBits((num_records + 1) * 32); + int64_t offsets_len = array_data.buffers[1]->capacity(); + ARROW_RETURN_IF( + offsets_len < min_offsets_len, + Status::Invalid("offsets buffer too small for ", field.name(), + " minimum required ", min_offsets_len, " actual ", offsets_len)); + + // check that it's resizable. 
+ auto resizable = dynamic_cast(array_data.buffers[2].get()); + ARROW_RETURN_IF( + resizable == nullptr, + Status::Invalid("data buffer for varlen output vectors must be resizable")); + } else if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) { + // verify size of data buffer. + const auto& fw_type = dynamic_cast(*field.type()); + int64_t min_data_len = + arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width()); + int64_t data_len = array_data.buffers[1]->capacity(); + ARROW_RETURN_IF(data_len < min_data_len, + Status::Invalid("Data buffer too small for ", field.name())); + } else { + return Status::Invalid("Unsupported output data type " + field.type()->ToString()); + } return Status::OK(); } diff --git a/cpp/src/gandiva/projector.h b/cpp/src/gandiva/projector.h index 0aa09dfe3bd..ff2fbc7b38d 100644 --- a/cpp/src/gandiva/projector.h +++ b/cpp/src/gandiva/projector.h @@ -122,8 +122,8 @@ class GANDIVA_EXPORT Projector { const FieldVector& output_fields, std::shared_ptr); /// Allocate an ArrowData of length 'length'. - Status AllocArrayData(const DataTypePtr& type, int64_t length, arrow::MemoryPool* pool, - ArrayDataPtr* array_data); + Status AllocArrayData(const DataTypePtr& type, int64_t num_records, + arrow::MemoryPool* pool, ArrayDataPtr* array_data); /// Validate that the ArrayData has sufficient capacity to accomodate 'num_records'. Status ValidateArrayDataCapacity(const arrow::ArrayData& array_data, diff --git a/cpp/src/gandiva/tests/utf8_test.cc b/cpp/src/gandiva/tests/utf8_test.cc index ea9a76ce805..103992d23fe 100644 --- a/cpp/src/gandiva/tests/utf8_test.cc +++ b/cpp/src/gandiva/tests/utf8_test.cc @@ -506,19 +506,37 @@ TEST_F(TestUtf8, TestIsNull) { TEST_F(TestUtf8, TestVarlenOutput) { // schema for input fields - auto field_a = field("a", utf8()); + auto field_a = field("a", boolean()); auto schema = arrow::schema({field_a}); // build expressions. 
- auto expr = TreeExprBuilder::MakeExpression(TreeExprBuilder::MakeField(field_a), - field("res", utf8())); + // if (a) literal_hi else literal_bye + auto if_node = TreeExprBuilder::MakeIf( + TreeExprBuilder::MakeField(field_a), TreeExprBuilder::MakeStringLiteral("hi"), + TreeExprBuilder::MakeStringLiteral("bye"), utf8()); + auto expr = TreeExprBuilder::MakeExpression(if_node, field("res", utf8())); // Build a projector for the expressions. std::shared_ptr projector; // assert that it fails gracefully. - ASSERT_RAISES(NotImplemented, - Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_in = + MakeArrowArrayBool({true, false, false, false}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_in}); + + // Evaluate expression + arrow::ArrayVector outputs; + ASSERT_OK(projector->Evaluate(*in_batch, pool_, &outputs)); + + // expected output + auto exp = MakeArrowArrayUtf8({"hi", "bye", "bye", "bye"}, {true, true, true, true}); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); } TEST_F(TestUtf8, TestCastVarChar) { diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java index ae93fba5991..93657e6f78f 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java @@ -29,6 +29,7 @@ import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType; import org.apache.arrow.vector.FixedWidthVector; import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VariableWidthVector; import org.apache.arrow.vector.ipc.message.ArrowBuffer; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; 
import org.apache.arrow.vector.types.pojo.Schema; @@ -235,16 +236,22 @@ private void evaluate(int numRows, List buffers, List buf bufSizes[idx++] = bufLayout.getSize(); } - long[] outAddrs = new long[2 * outColumns.size()]; - long[] outSizes = new long[2 * outColumns.size()]; + long[] outAddrs = new long[3 * outColumns.size()]; + long[] outSizes = new long[3 * outColumns.size()]; idx = 0; for (ValueVector valueVector : outColumns) { - if (!(valueVector instanceof FixedWidthVector)) { - throw new UnsupportedTypeException("Unsupported value vector type"); + boolean isFixedWith = valueVector instanceof FixedWidthVector; + boolean isVarWidth = valueVector instanceof VariableWidthVector; + if (!isFixedWith && !isVarWidth) { + throw new UnsupportedTypeException("Unsupported value vector type " + valueVector.getField().getFieldType()); } outAddrs[idx] = valueVector.getValidityBuffer().memoryAddress(); outSizes[idx++] = valueVector.getValidityBuffer().capacity(); + if (isVarWidth) { + outAddrs[idx] = valueVector.getOffsetBuffer().memoryAddress(); + outSizes[idx++] = valueVector.getOffsetBuffer().capacity(); + } outAddrs[idx] = valueVector.getDataBuffer().memoryAddress(); outSizes[idx++] = valueVector.getDataBuffer().capacity(); diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java index 62a12710cc7..2fd80910db8 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java @@ -38,6 +38,7 @@ import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.DateUnit; @@ -48,7 
+49,9 @@ import org.junit.Assert; import org.junit.Ignore; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -60,6 +63,9 @@ public class ProjectorTest extends BaseEvaluatorTest { private Charset utf8Charset = Charset.forName("UTF-8"); private Charset utf16Charset = Charset.forName("UTF-16"); + @Rule + public ExpectedException thrown = ExpectedException.none(); + List varBufs(String[] strings, Charset charset) { ArrowBuf offsetsBuffer = allocator.buffer((strings.length + 1) * 4); ArrowBuf dataBuffer = allocator.buffer(strings.length * 8); @@ -516,6 +522,74 @@ public void testStringFields() throws GandivaException { eval.close(); } + @Test + public void testStringOutput() throws GandivaException { + /* + * if (x >= 0) "hi" else "bye" + */ + + Field x = Field.nullable("x", new ArrowType.Int(32, true)); + + ArrowType retType = new ArrowType.Utf8(); + + TreeNode ifHiBye = TreeBuilder.makeIf( + TreeBuilder.makeFunction( + "greater_than_or_equal_to", + Lists.newArrayList( + TreeBuilder.makeField(x), + TreeBuilder.makeLiteral(0) + ), + boolType), + TreeBuilder.makeStringLiteral("hi"), + TreeBuilder.makeStringLiteral("bye"), + retType); + + ExpressionTree expr = TreeBuilder.makeExpression(ifHiBye, Field.nullable("res", retType)); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + // fill up input record batch + int numRows = 4; + byte[] validity = new byte[]{(byte) 255, 0}; + int[] xValues = new int[]{10, -10, 20, -20}; + String[] expected = new String[]{"hi", "bye", "hi", "bye"}; + ArrowBuf validityX = buf(validity); + ArrowBuf dataX = intBuf(xValues); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0)), + Lists.newArrayList( validityX, dataX)); + + // allocate data for output vector. 
+ VarCharVector outVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator); + outVector.allocateNew(64, numRows); + + + // evaluate expression + List output = new ArrayList<>(); + output.add(outVector); + eval.evaluate(batch, output); + + // match expected output. + for (int i = 0; i < numRows; i++) { + assertFalse(outVector.isNull(i)); + assertEquals(expected[i], new String(outVector.get(i))); + } + + // test with insufficient data buffer. + try { + outVector.allocateNew(4, numRows); + thrown.expect(GandivaException.class); + thrown.expectMessage("expand not implemented"); + eval.evaluate(batch, output); + } finally { + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + } + @Test public void testRegex() throws GandivaException { /* From 8192902f00a7ab74ca93e6f2d9316fd32d18923f Mon Sep 17 00:00:00 2001 From: Micah Kornfield Date: Mon, 1 Jul 2019 13:42:14 -0500 Subject: [PATCH 17/52] ARROW-5791: [C++] Fix infinite loop with more the 32768 columns. But really 32768 columns should be enough for anyone :) Author: Micah Kornfield Closes #4762 from emkornfield/csv and squashes the following commits: ab0504c16 lower number of columns in test to satisfy ming 8f53a8a58 remove test acfe2d894 remove cap, make min rows_in_chunk 512 08ddc2238 remove floor duplication 211472a12 powers of 2 are better b91a9e177 ARROW-5791: Fix infinite loop with more the 32768 columns. Cap max columns --- cpp/src/arrow/csv/parser-test.cc | 25 +++++++++++++++++++++++++ cpp/src/arrow/csv/parser.cc | 7 +++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/csv/parser-test.cc b/cpp/src/arrow/csv/parser-test.cc index 36552309b27..d1790b23da1 100644 --- a/cpp/src/arrow/csv/parser-test.cc +++ b/cpp/src/arrow/csv/parser-test.cc @@ -439,6 +439,31 @@ TEST(BlockParser, Escaping) { } } +// Generate test data with the given number of columns. 
+std::string MakeLotsOfCsvColumns(int32_t num_columns) { + std::string values, header; + header.reserve(num_columns * 10); + values.reserve(num_columns * 10); + for (int x = 0; x < num_columns; x++) { + if (x != 0) { + header += ","; + values += ","; + } + header += "c" + std::to_string(x); + values += std::to_string(x); + } + + header += "\n"; + values += "\n"; + return MakeCSVData({header, values}); +} + +TEST(BlockParser, LotsOfColumns) { + auto options = ParseOptions::Defaults(); + BlockParser parser(options); + AssertParseOk(parser, MakeLotsOfCsvColumns(1024 * 100)); +} + TEST(BlockParser, QuotedEscape) { auto options = ParseOptions::Defaults(); options.escaping = true; diff --git a/cpp/src/arrow/csv/parser.cc b/cpp/src/arrow/csv/parser.cc index a7ca71c9fd7..89c3f4cb168 100644 --- a/cpp/src/arrow/csv/parser.cc +++ b/cpp/src/arrow/csv/parser.cc @@ -397,16 +397,19 @@ Status BlockParser::DoParseSpecialized(const char* start, uint32_t size, bool is return ParseError("Empty CSV file or block: cannot infer number of columns"); } } + while (!finished_parsing && data < data_end && num_rows_ < max_num_rows_) { // We know the number of columns, so can presize a values array for // a given number of rows DCHECK_GE(num_cols_, 0); int32_t rows_in_chunk; + constexpr int32_t kTargetChunkSize = 32768; if (num_cols_ > 0) { - rows_in_chunk = std::min(32768 / num_cols_, max_num_rows_ - num_rows_); + rows_in_chunk = std::min(std::max(kTargetChunkSize / num_cols_, 512), + max_num_rows_ - num_rows_); } else { - rows_in_chunk = std::min(32768, max_num_rows_ - num_rows_); + rows_in_chunk = std::min(kTargetChunkSize, max_num_rows_ - num_rows_); } PresizedValuesWriter values_writer(pool_, rows_in_chunk, num_cols_); From 3b6b271e2f6c2540105a06c1fa5fb17381e5d1ba Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Mon, 1 Jul 2019 12:59:28 -0700 Subject: [PATCH 18/52] ARROW-5792: [Rust] Add TypeVisitor for parquet type. This trait is helpful when dealing with parquet type. 
Author: Renjie Liu Closes #4766 from liurenjie1024/arrow-5792 and squashes the following commits: 02d59a5f Fix build failure 908f1ec1 Add TypeVisitor. --- rust/parquet/src/schema/mod.rs | 1 + rust/parquet/src/schema/visitor.rs | 240 +++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 rust/parquet/src/schema/visitor.rs diff --git a/rust/parquet/src/schema/mod.rs b/rust/parquet/src/schema/mod.rs index 351ce973371..f689db3c1b8 100644 --- a/rust/parquet/src/schema/mod.rs +++ b/rust/parquet/src/schema/mod.rs @@ -64,3 +64,4 @@ pub mod parser; pub mod printer; pub mod types; +pub mod visitor; diff --git a/rust/parquet/src/schema/visitor.rs b/rust/parquet/src/schema/visitor.rs new file mode 100644 index 00000000000..6970f9ed47a --- /dev/null +++ b/rust/parquet/src/schema/visitor.rs @@ -0,0 +1,240 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::basic::{LogicalType, Repetition}; +use crate::errors::ParquetError::General; +use crate::errors::Result; +use crate::schema::types::{Type, TypePtr}; + +/// A utility trait to help user to traverse against parquet type. +pub trait TypeVisitor { + /// Called when a primitive type hit. 
+ fn visit_primitive(&mut self, primitive_type: TypePtr, context: C) -> Result; + + /// Default implementation when visiting a list. + /// + /// It checks list type definition and calls `visit_list_with_item` with extracted + /// item type. + /// + /// To fully understand this algorithm, please refer to + /// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md). + fn visit_list(&mut self, list_type: TypePtr, context: C) -> Result { + match list_type.as_ref() { + Type::PrimitiveType { .. } => panic!( + "{:?} is a list type and can't be processed as primitive.", + list_type + ), + Type::GroupType { + basic_info: _, + fields, + } if fields.len() == 1 => { + let list_item = fields.first().unwrap(); + + match list_item.as_ref() { + Type::PrimitiveType { .. } => { + if list_item.get_basic_info().repetition() == Repetition::REPEATED + { + self.visit_list_with_item( + list_type.clone(), + list_item, + context, + ) + } else { + Err(General( + "Primitive element type of list must be repeated." + .to_string(), + )) + } + } + Type::GroupType { + basic_info: _, + fields, + } => { + if fields.len() == 1 + && list_item.name() != "array" + && list_item.name() != format!("{}_tuple", list_type.name()) + { + self.visit_list_with_item( + list_type.clone(), + fields.first().unwrap(), + context, + ) + } else { + self.visit_list_with_item( + list_type.clone(), + list_item, + context, + ) + } + } + } + } + _ => Err(General( + "Group element type of list can only contain one field.".to_string(), + )), + } + } + + /// Called when a struct type hit. + fn visit_struct(&mut self, struct_type: TypePtr, context: C) -> Result; + + /// Called when a map type hit. + fn visit_map(&mut self, map_type: TypePtr, context: C) -> Result; + + /// A utility method which detects input type and calls corresponding method. 
+ fn dispatch(&mut self, cur_type: TypePtr, context: C) -> Result { + if cur_type.is_primitive() { + self.visit_primitive(cur_type, context) + } else { + match cur_type.get_basic_info().logical_type() { + LogicalType::LIST => self.visit_list(cur_type, context), + LogicalType::MAP | LogicalType::MAP_KEY_VALUE => { + self.visit_map(cur_type, context) + } + _ => self.visit_struct(cur_type, context), + } + } + } + + /// Called by `visit_list`. + fn visit_list_with_item( + &mut self, + list_type: TypePtr, + item_type: &Type, + context: C, + ) -> Result; +} + +#[cfg(test)] +mod tests { + use super::TypeVisitor; + use crate::basic::Type as PhysicalType; + use crate::errors::Result; + use crate::schema::parser::parse_message_type; + use crate::schema::types::{Type, TypePtr}; + use std::rc::Rc; + + struct TestVisitorContext {} + struct TestVisitor { + primitive_visited: bool, + struct_visited: bool, + list_visited: bool, + root_type: TypePtr, + } + + impl TypeVisitor for TestVisitor { + fn visit_primitive( + &mut self, + primitive_type: TypePtr, + _context: TestVisitorContext, + ) -> Result { + assert_eq!( + self.get_field_by_name(primitive_type.name()).as_ref(), + primitive_type.as_ref() + ); + self.primitive_visited = true; + Ok(true) + } + + fn visit_struct( + &mut self, + struct_type: TypePtr, + _context: TestVisitorContext, + ) -> Result { + assert_eq!( + self.get_field_by_name(struct_type.name()).as_ref(), + struct_type.as_ref() + ); + self.struct_visited = true; + Ok(true) + } + + fn visit_map( + &mut self, + _map_type: TypePtr, + _context: TestVisitorContext, + ) -> Result { + unimplemented!() + } + + fn visit_list_with_item( + &mut self, + list_type: TypePtr, + item_type: &Type, + _context: TestVisitorContext, + ) -> Result { + assert_eq!( + self.get_field_by_name(list_type.name()).as_ref(), + list_type.as_ref() + ); + assert_eq!("element", item_type.name()); + assert_eq!(PhysicalType::INT32, item_type.get_physical_type()); + self.list_visited = true; + Ok(true) + 
} + } + + impl TestVisitor { + fn new(root: TypePtr) -> Self { + Self { + primitive_visited: false, + struct_visited: false, + list_visited: false, + root_type: root, + } + } + + fn get_field_by_name(&self, name: &str) -> TypePtr { + self.root_type + .get_fields() + .iter() + .find(|t| t.name() == name) + .map(|t| t.clone()) + .unwrap() + } + } + + #[test] + fn test_visitor() { + let message_type = " + message spark_schema { + REQUIRED INT32 a; + OPTIONAL group inner_schema { + REQUIRED INT32 b; + REQUIRED DOUBLE c; + } + + OPTIONAL group e (LIST) { + REPEATED group list { + REQUIRED INT32 element; + } + } + "; + + let parquet_type = Rc::new(parse_message_type(&message_type).unwrap()); + + let mut visitor = TestVisitor::new(parquet_type.clone()); + for f in parquet_type.get_fields() { + let c = TestVisitorContext {}; + assert!(visitor.dispatch(f.clone(), c).unwrap()); + } + + assert!(visitor.struct_visited); + assert!(visitor.primitive_visited); + assert!(visitor.list_visited); + } +} From b7f27f005027f23b8eb21d75f794d88f2fd6d926 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 1 Jul 2019 13:00:54 -0700 Subject: [PATCH 19/52] ARROW-5358: [Rust] Implement equality check for ArrayData and Array This implements equality comparison for `Array` type which checks whether two arrays are identical in content. Besides the above, this adds two traits: `PrimitiveArrayOps` and `ListArrayOps`. The former exposes a few common operations between numeric arrays and boolean array, while the latter between list and binary arrays. 
Author: Chao Sun Closes #4643 from sunchao/ARROW-5358 and squashes the following commits: e40241b1 Fixes after rebasing 9a11efa4 Fix a bug in test 53ac33e6 Address comments 8663124c Replace expect with unwrap d3ffb27f ARROW-5358: Implement equality check for ArrayData and Array --- rust/arrow/src/array/array.rs | 95 ++++- rust/arrow/src/array/equal.rs | 741 ++++++++++++++++++++++++++++++++++ rust/arrow/src/array/mod.rs | 81 ++-- 3 files changed, 879 insertions(+), 38 deletions(-) create mode 100644 rust/arrow/src/array/equal.rs diff --git a/rust/arrow/src/array/array.rs b/rust/arrow/src/array/array.rs index f4af117f489..2c353d578f0 100644 --- a/rust/arrow/src/array/array.rs +++ b/rust/arrow/src/array/array.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::convert::From; +use std::convert::{From, TryFrom}; use std::fmt; use std::io::Write; use std::mem; @@ -27,6 +27,7 @@ use chrono::prelude::*; use super::*; use crate::buffer::{Buffer, MutableBuffer}; use crate::datatypes::*; +use crate::error::{ArrowError, Result}; use crate::memory; use crate::util::bit_util; @@ -41,7 +42,7 @@ const NANOSECONDS: i64 = 1_000_000_000; /// Trait for dealing with different types of array at runtime when the type of the /// array is not known in advance -pub trait Array: Send + Sync { +pub trait Array: Send + Sync + ArrayEqual { /// Returns the array as `Any` so that it can be downcast to a specific implementation fn as_any(&self) -> &Any; @@ -194,6 +195,45 @@ pub struct PrimitiveArray { raw_values: RawPtrBox, } +/// Common operations for primitive types, including numeric types and boolean type. +pub trait PrimitiveArrayOps { + fn values(&self) -> Buffer; + fn value(&self, i: usize) -> T::Native; +} + +// This is necessary when caller wants to access `PrimitiveArrayOps`'s methods with +// `ArrowPrimitiveType`. It doesn't have any implementation as the actual implementations +// are delegated to that of `ArrowNumericType` and `BooleanType`. 
+impl PrimitiveArrayOps for PrimitiveArray { + default fn values(&self) -> Buffer { + unimplemented!() + } + + default fn value(&self, _: usize) -> T::Native { + unimplemented!() + } +} + +impl PrimitiveArrayOps for PrimitiveArray { + fn values(&self) -> Buffer { + self.values() + } + + fn value(&self, i: usize) -> T::Native { + self.value(i) + } +} + +impl PrimitiveArrayOps for BooleanArray { + fn values(&self) -> Buffer { + self.values() + } + + fn value(&self, i: usize) -> bool { + self.value(i) + } +} + impl Array for PrimitiveArray { fn as_any(&self) -> &Any { self @@ -271,7 +311,6 @@ where /// /// If a data type cannot be converted to `NaiveDateTime`, a `None` is returned. /// A valid value is expected, thus the user should first check for validity. - /// TODO: extract constants into static variables pub fn value_as_datetime(&self, i: usize) -> Option { let v = i64::from(self.value(i)); match self.data_type() { @@ -651,6 +690,23 @@ impl From for PrimitiveArray { } } +/// Common operations for List types, currently `ListArray` and `BinaryArray`. +pub trait ListArrayOps { + fn value_offset_at(&self, i: usize) -> i32; +} + +impl ListArrayOps for ListArray { + fn value_offset_at(&self, i: usize) -> i32 { + self.value_offset_at(i) + } +} + +impl ListArrayOps for BinaryArray { + fn value_offset_at(&self, i: usize) -> i32 { + self.value_offset_at(i) + } +} + /// A list array where each element is a variable-sized sequence of values with the same /// type. 
pub struct ListArray { @@ -784,6 +840,16 @@ impl BinaryArray { self.value_offset_at(i + 1) - self.value_offset_at(i) } + /// Returns a clone of the value offset buffer + pub fn value_offsets(&self) -> Buffer { + self.data.buffers()[0].clone() + } + + /// Returns a clone of the value data buffer + pub fn value_data(&self) -> Buffer { + self.data.buffers()[1].clone() + } + #[inline] fn value_offset_at(&self, i: usize) -> i32 { unsafe { *self.value_offsets.get().offset(i as isize) } @@ -831,7 +897,7 @@ impl<'a> From> for BinaryArray { } } -impl<'a> From> for BinaryArray { +impl From> for BinaryArray { fn from(v: Vec<&[u8]>) -> Self { let mut offsets = Vec::with_capacity(v.len() + 1); let mut values = Vec::new(); @@ -851,6 +917,22 @@ impl<'a> From> for BinaryArray { } } +impl<'a> TryFrom>> for BinaryArray { + type Error = ArrowError; + + fn try_from(v: Vec>) -> Result { + let mut builder = BinaryBuilder::new(v.len()); + for val in v { + if let Some(s) = val { + builder.append_string(s)?; + } else { + builder.append(false)?; + } + } + Ok(builder.finish()) + } +} + /// Creates a `BinaryArray` from `List` array impl From for BinaryArray { fn from(v: ListArray) -> Self { @@ -907,6 +989,11 @@ impl StructArray { pub fn column(&self, pos: usize) -> &ArrayRef { &self.boxed_fields[pos] } + + /// Return the number of fields in this struct array + pub fn num_columns(&self) -> usize { + self.boxed_fields.len() + } } impl From for StructArray { diff --git a/rust/arrow/src/array/equal.rs b/rust/arrow/src/array/equal.rs new file mode 100644 index 00000000000..5f888ab5eac --- /dev/null +++ b/rust/arrow/src/array/equal.rs @@ -0,0 +1,741 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; +use crate::datatypes::*; +use crate::util::bit_util; + +/// Trait for `Array` equality. +pub trait ArrayEqual { + /// Returns true if this array is equal to the `other` array + fn equals(&self, other: &dyn Array) -> bool; + + /// Returns true if the range [start_idx, end_idx) is equal to + /// [other_start_idx, other_start_idx + end_idx - start_idx) in the `other` array + fn range_equals( + &self, + other: &dyn Array, + start_idx: usize, + end_idx: usize, + other_start_idx: usize, + ) -> bool; +} + +impl ArrayEqual for PrimitiveArray { + default fn equals(&self, other: &dyn Array) -> bool { + if !base_equal(&self.data(), &other.data()) { + return false; + } + + let value_buf = self.data_ref().buffers()[0].clone(); + let other_value_buf = other.data_ref().buffers()[0].clone(); + let byte_width = T::get_bit_width() / 8; + + if self.null_count() > 0 { + let values = value_buf.data(); + let other_values = other_value_buf.data(); + + for i in 0..self.len() { + if self.is_valid(i) { + let start = (i + self.offset()) * byte_width; + let data = &values[start..(start + byte_width)]; + let other_start = (i + other.offset()) * byte_width; + let other_data = + &other_values[other_start..(other_start + byte_width)]; + if data != other_data { + return false; + } + } + } + } else { + let start = self.offset() * byte_width; + let other_start = other.offset() * byte_width; + let len = 
self.len() * byte_width; + let data = &value_buf.data()[start..(start + len)]; + let other_data = &other_value_buf.data()[other_start..(other_start + len)]; + if data != other_data { + return false; + } + } + + true + } + + default fn range_equals( + &self, + other: &dyn Array, + start_idx: usize, + end_idx: usize, + other_start_idx: usize, + ) -> bool { + assert!(other_start_idx + (end_idx - start_idx) <= other.len()); + let other = other.as_any().downcast_ref::>().unwrap(); + + let mut j = other_start_idx; + for i in start_idx..end_idx { + let is_null = self.is_null(i); + let other_is_null = other.is_null(j); + if is_null != other_is_null || (!is_null && self.value(i) != other.value(j)) { + return false; + } + j += 1; + } + + true + } +} + +impl ArrayEqual for BooleanArray { + fn equals(&self, other: &dyn Array) -> bool { + if !base_equal(&self.data(), &other.data()) { + return false; + } + + let values = self.data_ref().buffers()[0].data(); + let other_values = other.data_ref().buffers()[0].data(); + + // TODO: we can do this more efficiently if all values are not-null + for i in 0..self.len() { + if self.is_valid(i) { + if bit_util::get_bit(values, i + self.offset()) + != bit_util::get_bit(other_values, i + other.offset()) + { + return false; + } + } + } + + true + } +} + +impl PartialEq for PrimitiveArray { + fn eq(&self, other: &PrimitiveArray) -> bool { + self.equals(other) + } +} + +impl ArrayEqual for ListArray { + fn equals(&self, other: &dyn Array) -> bool { + if !base_equal(&self.data(), &other.data()) { + return false; + } + + let other = other.as_any().downcast_ref::().unwrap(); + + if !value_offset_equal(self, other) { + return false; + } + + if !self.values().range_equals( + &*other.values(), + self.value_offset(0) as usize, + self.value_offset(self.len()) as usize, + other.value_offset(0) as usize, + ) { + return false; + } + + true + } + + fn range_equals( + &self, + other: &dyn Array, + start_idx: usize, + end_idx: usize, + other_start_idx: 
usize, + ) -> bool { + assert!(other_start_idx + (end_idx - start_idx) <= other.len()); + let other = other.as_any().downcast_ref::().unwrap(); + + let mut j = other_start_idx; + for i in start_idx..end_idx { + let is_null = self.is_null(i); + let other_is_null = other.is_null(j); + + if is_null != other_is_null { + return false; + } + + if is_null { + continue; + } + + let start_offset = self.value_offset(i) as usize; + let end_offset = self.value_offset(i + 1) as usize; + let other_start_offset = other.value_offset(j) as usize; + let other_end_offset = other.value_offset(j + 1) as usize; + + if end_offset - start_offset != other_end_offset - other_start_offset { + return false; + } + + if !self.values().range_equals( + &*other.values(), + start_offset, + end_offset, + other_start_offset, + ) { + return false; + } + + j += 1; + } + + true + } +} + +impl ArrayEqual for BinaryArray { + fn equals(&self, other: &dyn Array) -> bool { + if !base_equal(&self.data(), &other.data()) { + return false; + } + + let other = other.as_any().downcast_ref::().unwrap(); + + if !value_offset_equal(self, other) { + return false; + } + + // TODO: handle null & length == 0 case? 
+ + let value_buf = self.value_data(); + let other_value_buf = other.value_data(); + let value_data = value_buf.data(); + let other_value_data = other_value_buf.data(); + + if self.null_count() == 0 { + // No offset in both - just do memcmp + if self.offset() == 0 && other.offset() == 0 { + let len = self.value_offset(self.len()) as usize; + return value_data[..len] == other_value_data[..len]; + } else { + let start = self.value_offset(0) as usize; + let other_start = other.value_offset(0) as usize; + let len = (self.value_offset(self.len()) - self.value_offset(0)) as usize; + return value_data[start..(start + len)] + == other_value_data[other_start..(other_start + len)]; + } + } else { + for i in 0..self.len() { + if self.is_null(i) { + continue; + } + + let start = self.value_offset(i) as usize; + let other_start = other.value_offset(i) as usize; + let len = self.value_length(i) as usize; + if value_data[start..(start + len)] + != other_value_data[other_start..(other_start + len)] + { + return false; + } + } + } + + true + } + + fn range_equals( + &self, + other: &dyn Array, + start_idx: usize, + end_idx: usize, + other_start_idx: usize, + ) -> bool { + assert!(other_start_idx + (end_idx - start_idx) <= other.len()); + let other = other.as_any().downcast_ref::().unwrap(); + + let mut j = other_start_idx; + for i in start_idx..end_idx { + let is_null = self.is_null(i); + let other_is_null = other.is_null(j); + + if is_null != other_is_null { + return false; + } + + if is_null { + continue; + } + + let start_offset = self.value_offset(i) as usize; + let end_offset = self.value_offset(i + 1) as usize; + let other_start_offset = other.value_offset(j) as usize; + let other_end_offset = other.value_offset(j + 1) as usize; + + if end_offset - start_offset != other_end_offset - other_start_offset { + return false; + } + + let value_buf = self.value_data(); + let other_value_buf = other.value_data(); + let value_data = value_buf.data(); + let other_value_data = 
other_value_buf.data(); + + if end_offset - start_offset > 0 { + let len = end_offset - start_offset; + if value_data[start_offset..(start_offset + len)] + != other_value_data[other_start_offset..(other_start_offset + len)] + { + return false; + } + } + + j += 1; + } + + true + } +} + +impl ArrayEqual for StructArray { + fn equals(&self, other: &dyn Array) -> bool { + if !base_equal(&self.data(), &other.data()) { + return false; + } + + let other = other.as_any().downcast_ref::().unwrap(); + + for i in 0..self.len() { + let is_null = self.is_null(i); + let other_is_null = other.is_null(i); + + if is_null != other_is_null { + return false; + } + + if is_null { + continue; + } + for j in 0..self.num_columns() { + if !self.column(j).range_equals(&**other.column(j), i, i + 1, i) { + return false; + } + } + } + + true + } + + fn range_equals( + &self, + other: &dyn Array, + start_idx: usize, + end_idx: usize, + other_start_idx: usize, + ) -> bool { + assert!(other_start_idx + (end_idx - start_idx) <= other.len()); + let other = other.as_any().downcast_ref::().unwrap(); + + let mut j = other_start_idx; + for i in start_idx..end_idx { + let is_null = self.is_null(i); + let other_is_null = other.is_null(i); + + if is_null != other_is_null { + return false; + } + + if is_null { + continue; + } + for k in 0..self.num_columns() { + if !self.column(k).range_equals(&**other.column(k), i, i + 1, j) { + return false; + } + } + + j += 1; + } + + true + } +} + +// Compare if the common basic fields between the two arrays are equal +fn base_equal(this: &ArrayDataRef, other: &ArrayDataRef) -> bool { + if this.data_type() != other.data_type() { + return false; + } + if this.len != other.len { + return false; + } + if this.null_count != other.null_count { + return false; + } + if this.null_count > 0 { + let null_bitmap = this.null_bitmap().as_ref().unwrap(); + let other_null_bitmap = other.null_bitmap().as_ref().unwrap(); + let null_buf = null_bitmap.bits.data(); + let other_null_buf = 
other_null_bitmap.bits.data(); + for i in 0..this.len() { + if bit_util::get_bit(null_buf, i + this.offset()) + != bit_util::get_bit(other_null_buf, i + other.offset()) + { + return false; + } + } + } + true +} + +// Compare if the value offsets are equal between the two list arrays +fn value_offset_equal(this: &T, other: &T) -> bool { + // Check if offsets differ + if this.offset() == 0 && other.offset() == 0 { + let offset_data = &this.data_ref().buffers()[0]; + let other_offset_data = &other.data_ref().buffers()[0]; + return offset_data.data()[0..((this.len() + 1) * 4)] + == other_offset_data.data()[0..((other.len() + 1) * 4)]; + } + + // The expensive case + for i in 0..this.len() + 1 { + if this.value_offset_at(i) - this.value_offset_at(0) + != other.value_offset_at(i) - other.value_offset_at(0) + { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::convert::TryFrom; + + use crate::error::Result; + + #[test] + fn test_primitive_equal() { + let a = Int32Array::from(vec![1, 2, 3]); + let b = Int32Array::from(vec![1, 2, 3]); + assert!(a.equals(&b)); + assert!(b.equals(&a)); + + let b = Int32Array::from(vec![1, 2, 4]); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + // Test the case where null_count > 0 + + let a = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); + let b = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); + assert!(a.equals(&b)); + assert!(b.equals(&a)); + + let b = Int32Array::from(vec![Some(1), None, None, Some(3)]); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + let b = Int32Array::from(vec![Some(1), None, Some(2), Some(4)]); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + // Test the case where offset != 0 + + let a_slice = a.slice(1, 2); + let b_slice = b.slice(1, 2); + assert!(a_slice.equals(&*b_slice)); + assert!(b_slice.equals(&*a_slice)); + } + + #[test] + fn test_boolean_equal() { + let a = BooleanArray::from(vec![false, false, true]); + let b = 
BooleanArray::from(vec![false, false, true]); + assert!(a.equals(&b)); + assert!(b.equals(&a)); + + let b = BooleanArray::from(vec![false, false, false]); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + // Test the case where null_count > 0 + + let a = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + let b = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + assert!(a.equals(&b)); + assert!(b.equals(&a)); + + let b = BooleanArray::from(vec![None, None, None, Some(true)]); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + let b = BooleanArray::from(vec![Some(true), None, None, Some(true)]); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + // Test the case where offset != 0 + + let a = BooleanArray::from(vec![false, true, false, true, false, false, true]); + let b = BooleanArray::from(vec![false, false, false, true, false, true, true]); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + let a_slice = a.slice(2, 3); + let b_slice = b.slice(2, 3); + assert!(a_slice.equals(&*b_slice)); + assert!(b_slice.equals(&*a_slice)); + + let a_slice = a.slice(3, 4); + let b_slice = b.slice(3, 4); + assert!(!a_slice.equals(&*b_slice)); + assert!(!b_slice.equals(&*a_slice)); + } + + #[test] + fn test_list_equal() { + let mut a_builder = ListBuilder::new(Int32Builder::new(10)); + let mut b_builder = ListBuilder::new(Int32Builder::new(10)); + + let a = create_list_array(&mut a_builder, &[Some(&[1, 2, 3]), Some(&[4, 5, 6])]) + .unwrap(); + let b = create_list_array(&mut b_builder, &[Some(&[1, 2, 3]), Some(&[4, 5, 6])]) + .unwrap(); + + assert!(a.equals(&b)); + assert!(b.equals(&a)); + + let b = create_list_array(&mut a_builder, &[Some(&[1, 2, 3]), Some(&[4, 5, 7])]) + .unwrap(); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + // Test the case where null_count > 0 + + let a = create_list_array( + &mut a_builder, + &[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None], + ) + .unwrap(); + let b = create_list_array( + &mut 
a_builder, + &[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None], + ) + .unwrap(); + assert!(a.equals(&b)); + assert!(b.equals(&a)); + + let b = create_list_array( + &mut a_builder, + &[ + Some(&[1, 2]), + None, + Some(&[5, 6]), + Some(&[3, 4]), + None, + None, + ], + ) + .unwrap(); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + let b = create_list_array( + &mut a_builder, + &[Some(&[1, 2]), None, None, Some(&[3, 5]), None, None], + ) + .unwrap(); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + // Test the case where offset != 0 + + let a_slice = a.slice(0, 3); + let b_slice = b.slice(0, 3); + assert!(a_slice.equals(&*b_slice)); + assert!(b_slice.equals(&*a_slice)); + + let a_slice = a.slice(0, 5); + let b_slice = b.slice(0, 5); + assert!(!a_slice.equals(&*b_slice)); + assert!(!b_slice.equals(&*a_slice)); + + let a_slice = a.slice(4, 1); + let b_slice = b.slice(4, 1); + assert!(a_slice.equals(&*b_slice)); + assert!(b_slice.equals(&*a_slice)); + } + + #[test] + fn test_binary_equal() { + let a = BinaryArray::from(vec!["hello", "world"]); + let b = BinaryArray::from(vec!["hello", "world"]); + assert!(a.equals(&b)); + assert!(b.equals(&a)); + + let b = BinaryArray::from(vec!["hello", "arrow"]); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + // Test the case where null_count > 0 + + let a = BinaryArray::try_from(vec![ + Some("hello"), + None, + None, + Some("world"), + None, + None, + ]) + .unwrap(); + + let b = BinaryArray::try_from(vec![ + Some("hello"), + None, + None, + Some("world"), + None, + None, + ]) + .unwrap(); + assert!(a.equals(&b)); + assert!(b.equals(&a)); + + let b = BinaryArray::try_from(vec![ + Some("hello"), + Some("foo"), + None, + Some("world"), + None, + None, + ]) + .unwrap(); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + let b = BinaryArray::try_from(vec![ + Some("hello"), + None, + None, + Some("arrow"), + None, + None, + ]) + .unwrap(); + assert!(!a.equals(&b)); + assert!(!b.equals(&a)); + + // Test the 
case where offset != 0 + + let a_slice = a.slice(0, 3); + let b_slice = b.slice(0, 3); + assert!(a_slice.equals(&*b_slice)); + assert!(b_slice.equals(&*a_slice)); + + let a_slice = a.slice(0, 5); + let b_slice = b.slice(0, 5); + assert!(!a_slice.equals(&*b_slice)); + assert!(!b_slice.equals(&*a_slice)); + + let a_slice = a.slice(4, 1); + let b_slice = b.slice(4, 1); + assert!(a_slice.equals(&*b_slice)); + assert!(b_slice.equals(&*a_slice)); + } + + #[test] + fn test_struct_equal() { + let string_builder = BinaryBuilder::new(5); + let int_builder = Int32Builder::new(5); + + let mut fields = Vec::new(); + let mut field_builders = Vec::new(); + fields.push(Field::new("f1", DataType::Utf8, false)); + field_builders.push(Box::new(string_builder) as Box); + fields.push(Field::new("f2", DataType::Int32, false)); + field_builders.push(Box::new(int_builder) as Box); + + let mut builder = StructBuilder::new(fields, field_builders); + + let a = create_struct_array( + &mut builder, + &[Some("joe"), None, None, Some("mark"), Some("doe")], + &[Some(1), Some(2), None, Some(4), Some(5)], + &[true, true, false, true, true], + ) + .unwrap(); + let b = create_struct_array( + &mut builder, + &[Some("joe"), None, None, Some("mark"), Some("doe")], + &[Some(1), Some(2), None, Some(4), Some(5)], + &[true, true, false, true, true], + ) + .unwrap(); + + assert!(a.equals(&b)); + assert!(b.equals(&a)); + } + + fn create_list_array<'a, U: AsRef<[i32]>, T: AsRef<[Option]>>( + builder: &'a mut ListBuilder, + data: T, + ) -> Result { + for d in data.as_ref() { + if let Some(v) = d { + builder.values().append_slice(v.as_ref())?; + builder.append(true)? + } else { + builder.append(false)? 
+ } + } + Ok(builder.finish()) + } + + fn create_struct_array< + 'a, + T: AsRef<[Option<&'a str>]>, + U: AsRef<[Option]>, + V: AsRef<[bool]>, + >( + builder: &'a mut StructBuilder, + first: T, + second: U, + is_valid: V, + ) -> Result { + let string_builder = builder.field_builder::(0).unwrap(); + for v in first.as_ref() { + if let Some(s) = v { + string_builder.append_string(s)?; + } else { + string_builder.append_null()?; + } + } + + let int_builder = builder.field_builder::(1).unwrap(); + for v in second.as_ref() { + if let Some(i) = v { + int_builder.append_value(*i)?; + } else { + int_builder.append_null()?; + } + } + + for v in is_valid.as_ref() { + builder.append(*v)? + } + + Ok(builder.finish()) + } +} diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs index aa14f0f2284..47e4219c865 100644 --- a/rust/arrow/src/array/mod.rs +++ b/rust/arrow/src/array/mod.rs @@ -57,6 +57,11 @@ mod array; mod builder; mod data; +mod equal; + +use crate::datatypes::*; + +// --------------------- Array & ArrayData --------------------- pub use self::array::Array; pub use self::array::ArrayRef; @@ -64,7 +69,41 @@ pub use self::data::ArrayData; pub use self::data::ArrayDataBuilder; pub use self::data::ArrayDataRef; -use crate::datatypes::*; +pub use self::array::BinaryArray; +pub use self::array::ListArray; +pub use self::array::PrimitiveArray; +pub use self::array::StructArray; + +pub(crate) use self::array::make_array; + +pub type BooleanArray = PrimitiveArray; +pub type Int8Array = PrimitiveArray; +pub type Int16Array = PrimitiveArray; +pub type Int32Array = PrimitiveArray; +pub type Int64Array = PrimitiveArray; +pub type UInt8Array = PrimitiveArray; +pub type UInt16Array = PrimitiveArray; +pub type UInt32Array = PrimitiveArray; +pub type UInt64Array = PrimitiveArray; +pub type Float32Array = PrimitiveArray; +pub type Float64Array = PrimitiveArray; + +pub type TimestampSecondArray = PrimitiveArray; +pub type TimestampMillisecondArray = PrimitiveArray; +pub 
type TimestampMicrosecondArray = PrimitiveArray; +pub type TimestampNanosecondArray = PrimitiveArray; +pub type Date32Array = PrimitiveArray; +pub type Date64Array = PrimitiveArray; +pub type Time32SecondArray = PrimitiveArray; +pub type Time32MillisecondArray = PrimitiveArray; +pub type Time64MicrosecondArray = PrimitiveArray; +pub type Time64NanosecondArray = PrimitiveArray; +// TODO add interval + +pub use self::array::ListArrayOps; +pub use self::array::PrimitiveArrayOps; + +// --------------------- Array Builder --------------------- pub use self::builder::BufferBuilder; pub use self::builder::BufferBuilderTrait; @@ -92,7 +131,12 @@ pub type Time32MillisecondBufferBuilder = BufferBuilder; pub type Time64MicrosecondBufferBuilder = BufferBuilder; pub type Time64NanosecondBufferBuilder = BufferBuilder; +pub use self::builder::ArrayBuilder; +pub use self::builder::BinaryBuilder; +pub use self::builder::ListBuilder; pub use self::builder::PrimitiveBuilder; +pub use self::builder::StructBuilder; + pub type BooleanBuilder = PrimitiveBuilder; pub type Int8Builder = PrimitiveBuilder; pub type Int16Builder = PrimitiveBuilder; @@ -116,37 +160,6 @@ pub type Time32MillisecondBuilder = PrimitiveBuilder; pub type Time64MicrosecondBuilder = PrimitiveBuilder; pub type Time64NanosecondBuilder = PrimitiveBuilder; -pub use self::builder::BinaryBuilder; -pub use self::builder::ListBuilder; -pub use self::builder::StructBuilder; - -pub use self::array::BinaryArray; -pub use self::array::ListArray; -pub use self::array::PrimitiveArray; -pub use self::array::StructArray; - -pub(crate) use self::array::make_array; - -pub type BooleanArray = PrimitiveArray; -pub type Int8Array = PrimitiveArray; -pub type Int16Array = PrimitiveArray; -pub type Int32Array = PrimitiveArray; -pub type Int64Array = PrimitiveArray; -pub type UInt8Array = PrimitiveArray; -pub type UInt16Array = PrimitiveArray; -pub type UInt32Array = PrimitiveArray; -pub type UInt64Array = PrimitiveArray; -pub type 
Float32Array = PrimitiveArray; -pub type Float64Array = PrimitiveArray; +// --------------------- Array Equality --------------------- -pub type TimestampSecondArray = PrimitiveArray; -pub type TimestampMillisecondArray = PrimitiveArray; -pub type TimestampMicrosecondArray = PrimitiveArray; -pub type TimestampNanosecondArray = PrimitiveArray; -pub type Date32Array = PrimitiveArray; -pub type Date64Array = PrimitiveArray; -pub type Time32SecondArray = PrimitiveArray; -pub type Time32MillisecondArray = PrimitiveArray; -pub type Time64MicrosecondArray = PrimitiveArray; -pub type Time64NanosecondArray = PrimitiveArray; -// TODO add interval +pub use self::equal::ArrayEqual; From 9145c1591aedbd141454cfc7b6aad5190c0fb30e Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Tue, 2 Jul 2019 05:59:34 +0900 Subject: [PATCH 20/52] ARROW-5820: [Release] Remove undefined variable check from verify script External shell scripts may refer unbound variable: /tmp/arrow-0.14.0.yum2X/apache-arrow-0.14.0/test-miniconda/etc/profile.d/conda.sh: line 55: PS1: unbound variable Author: Sutou Kouhei Closes #4773 from kou/release-verify-without-u and squashes the following commits: e27270390 Remove undefined variable check from verify script --- dev/release/verify-release-candidate.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index f694fb4efc0..cf8df75eb36 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -46,7 +46,6 @@ case $# in esac set -e -set -u set -x set -o pipefail From 52c04e4cae2244fc9a5d362c9e61ce8b3a65bd16 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Tue, 2 Jul 2019 06:01:07 +0900 Subject: [PATCH 21/52] ARROW-5816: [Release] Do not curl in background in verify-release-candidate.sh This is a temporary fix Author: Wes McKinney Closes #4768 from wesm/rc-no-curl-download-background and squashes the following commits: 171f92ec5 Do not curl in 
background --- dev/release/verify-release-candidate.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index cf8df75eb36..8b25d304c23 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -129,9 +129,8 @@ download_bintray_files() { --fail \ --location \ --output ${file} \ - https://dl.bintray.com/${BINTRAY_REPOSITORY}/${file} & + https://dl.bintray.com/${BINTRAY_REPOSITORY}/${file} done - wait } test_binary() { From 8ec2fd0d7dc686228e488dc7a56815c1ffd97c66 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 1 Jul 2019 17:39:07 -0500 Subject: [PATCH 22/52] ARROW-5564: [C++] Use uriparser from conda-forge Author: Antoine Pitrou Closes #4767 from pitrou/ARROW-5564-conda-uriparser and squashes the following commits: 3c422c947 ARROW-5564: Use uriparser from conda-forge --- ci/conda_env_cpp.yml | 1 + cpp/cmake_modules/ThirdpartyToolchain.cmake | 4 ---- dev/tasks/conda-recipes/arrow-cpp/meta.yaml | 2 ++ 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ci/conda_env_cpp.yml b/ci/conda_env_cpp.yml index e34d2bf299d..fd21ed8d3fa 100644 --- a/ci/conda_env_cpp.yml +++ b/ci/conda_env_cpp.yml @@ -36,5 +36,6 @@ python rapidjson snappy thrift-cpp>=0.11.0 +uriparser zlib zstd diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index f6677e0165e..b91312376fb 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -94,10 +94,6 @@ if(ARROW_DEPENDENCY_SOURCE STREQUAL "CONDA") endif() set(ARROW_ACTUAL_DEPENDENCY_SOURCE "SYSTEM") message(STATUS "Using CONDA_PREFIX for ARROW_PACKAGE_PREFIX: ${ARROW_PACKAGE_PREFIX}") - # ARROW-5564: Remove this when uriparser gets a conda package - if("${uriparser_SOURCE}" STREQUAL "") - set(uriparser_SOURCE "AUTO") - endif() else() set(ARROW_ACTUAL_DEPENDENCY_SOURCE 
"${ARROW_DEPENDENCY_SOURCE}") endif() diff --git a/dev/tasks/conda-recipes/arrow-cpp/meta.yaml b/dev/tasks/conda-recipes/arrow-cpp/meta.yaml index 3fd43f3c859..4638980f128 100644 --- a/dev/tasks/conda-recipes/arrow-cpp/meta.yaml +++ b/dev/tasks/conda-recipes/arrow-cpp/meta.yaml @@ -39,6 +39,7 @@ requirements: - re2 - snappy - thrift-cpp >=0.11 + - uriparser - zlib - zstd @@ -55,6 +56,7 @@ requirements: - python - re2 - snappy + - uriparser - zlib - zstd From 3cdd0cf0d6d99d0a84f6a1c96bcc158e1148ba16 Mon Sep 17 00:00:00 2001 From: Mark Mikofski Date: Mon, 1 Jul 2019 16:10:50 -0700 Subject: [PATCH 23/52] [Website] Fix incorrect expansion of "SIMD" term [SIMD](https://www.google.com/search?q=SIMD) should be single "instruction" not "input" --- site/index.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/site/index.html b/site/index.html index 4d5995ac54a..4aab88e5409 100644 --- a/site/index.html +++ b/site/index.html @@ -23,7 +23,7 @@

Fast

-

Apache Arrow™ enables execution engines to take advantage of the latest SIMD (Single input multiple data) operations included in modern processors, for native vectorized optimization of analytical data processing. Columnar layout is optimized for data locality for better performance on modern hardware like CPUs and GPUs.

+

Apache Arrow™ enables execution engines to take advantage of the latest SIMD (Single instruction, multiple data) operations included in modern processors, for native vectorized optimization of analytical data processing. Columnar layout is optimized for data locality for better performance on modern hardware like CPUs and GPUs.

The Arrow memory format supports zero-copy reads for lightning-fast data access without serialization overhead.

From 7adbe93014b7679fac1e1df1286767445fa27bce Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Mon, 1 Jul 2019 23:17:20 -0500 Subject: [PATCH 24/52] [Release] Test Arrow Flight in Windows release verification script This works for me with 0.14.0 rc0. On account of ARROW-5817 I had to check manually that the Python unit tests passed --- dev/release/verify-release-candidate.bat | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dev/release/verify-release-candidate.bat b/dev/release/verify-release-candidate.bat index 3f6d95c17cf..299297e95c5 100644 --- a/dev/release/verify-release-candidate.bat +++ b/dev/release/verify-release-candidate.bat @@ -75,6 +75,7 @@ cmake -G "%GENERATOR%" ^ -DGTest_SOURCE=BUNDLED ^ -DCMAKE_BUILD_TYPE=%CONFIGURATION% ^ -DARROW_CXXFLAGS="/MP" ^ + -DARROW_FLIGHT=ON ^ -DARROW_PYTHON=ON ^ -DARROW_PARQUET=ON ^ .. || exit /B @@ -94,11 +95,11 @@ ctest -VV || exit /B popd @rem Build and import pyarrow -@rem parquet-cpp has some additional runtime dependencies that we need to figure out -@rem see PARQUET-1018 pushd %ARROW_SOURCE%\python -python setup.py build_ext --inplace --with-parquet --bundle-arrow-cpp bdist_wheel || exit /B +set PYARROW_WITH_FLIGHT=1 +set PYARROW_WITH_PARQUET=1 +python setup.py build_ext --inplace --bundle-arrow-cpp bdist_wheel || exit /B py.test pyarrow -v -s --parquet || exit /B popd From e02ec4e80aec3c092a72890117574f34b931de30 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Mon, 1 Jul 2019 23:35:03 -0500 Subject: [PATCH 25/52] ARROW-5466: [Java][CI] Dockerize Java CI, run all JDK builds in single Travis entry Since OpenJDK9 has been superseded by OpenJDK11, it is not available in package repositories, so I'm not sure it's worth maintaining a build for this. 
I have pushed a pre-built Docker image for this to https://cloud.docker.com/u/ursalab/repository/docker/ursalab/arrow-ci-java-all-jdks Author: Wes McKinney Closes #4761 from wesm/java-dockerify and squashes the following commits: 8b8103729 Actually run Java unit tests cf568a28a Code review feedback 6cb5e3f70 Build Javadoc in docker-compose job d9c390092 Run all Java builds in a single Dockerized Travis CI entry --- .travis.yml | 26 +++----------- ci/docker_build_java.sh | 17 +++++++-- ...ipt_javadoc.sh => docker_java_test_all.sh} | 16 +++++---- docker-compose.yml | 12 +++++++ java/Dockerfile.all-jdks | 35 +++++++++++++++++++ 5 files changed, 75 insertions(+), 31 deletions(-) rename ci/{travis_script_javadoc.sh => docker_java_test_all.sh} (69%) create mode 100644 java/Dockerfile.all-jdks diff --git a/.travis.yml b/.travis.yml index 5dc901561e8..117f5dca683 100644 --- a/.travis.yml +++ b/.travis.yml @@ -199,32 +199,14 @@ matrix: - if [ $ARROW_CI_PYTHON_AFFECTED == "1" ]; then docker-compose pull python-manylinux1; fi script: - if [ $ARROW_CI_PYTHON_AFFECTED == "1" ]; then $TRAVIS_BUILD_DIR/ci/travis_script_manylinux.sh; fi - - name: "Java w/ OpenJDK 8" - language: java - os: linux - jdk: openjdk8 - before_script: - - if [ $ARROW_CI_JAVA_AFFECTED != "1" ]; then exit; fi - - $TRAVIS_BUILD_DIR/ci/travis_install_linux.sh - script: - - $TRAVIS_BUILD_DIR/ci/travis_script_java.sh - - $TRAVIS_BUILD_DIR/ci/travis_script_javadoc.sh - - name: "Java w/ OpenJDK 9" - language: java - os: linux - jdk: openjdk9 - before_script: - - if [ $ARROW_CI_JAVA_AFFECTED != "1" ]; then exit; fi - script: - - $TRAVIS_BUILD_DIR/ci/travis_script_java.sh - - name: "Java w/ OpenJDK 11" - language: java + - name: "Java OpenJDK8 and OpenJDK11" + language: cpp os: linux - jdk: openjdk11 before_script: - if [ $ARROW_CI_JAVA_AFFECTED != "1" ]; then exit; fi + - docker-compose pull java-all-jdks script: - - $TRAVIS_BUILD_DIR/ci/travis_script_java.sh + - docker-compose run java-all-jdks - name: 
"Integration w/ OpenJDK 8, conda-forge toolchain" language: java os: linux diff --git a/ci/docker_build_java.sh b/ci/docker_build_java.sh index f3dd3f1446b..e6516b77831 100755 --- a/ci/docker_build_java.sh +++ b/ci/docker_build_java.sh @@ -25,10 +25,23 @@ mkdir -p /build/java arrow_src=/build/java/arrow +# Remove any pre-existing artifacts +rm -rf $arrow_src + pushd /arrow - rsync -a header java format integration $arrow_src +rsync -a header java format integration $arrow_src popd +JAVA_ARGS= +if [ "$ARROW_JAVA_RUN_TESTS" != "1" ]; then + JAVA_ARGS=-DskipTests +fi + pushd $arrow_src/java - mvn -B -DskipTests -Drat.skip=true install +mvn -B $JAVA_ARGS -Drat.skip=true install + +if [ "$ARROW_JAVADOC" == "1" ]; then + export MAVEN_OPTS="$MAVEN_OPTS -Dorg.slf4j.simpleLogger.defaultLogLevel=warn" + mvn -B site +fi popd diff --git a/ci/travis_script_javadoc.sh b/ci/docker_java_test_all.sh similarity index 69% rename from ci/travis_script_javadoc.sh rename to ci/docker_java_test_all.sh index 755d4628f20..1466907d9c4 100755 --- a/ci/travis_script_javadoc.sh +++ b/ci/docker_java_test_all.sh @@ -1,5 +1,4 @@ #!/usr/bin/env bash - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file # distributed with this work for additional information @@ -19,13 +18,16 @@ set -e -source $TRAVIS_BUILD_DIR/ci/travis_env_common.sh +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -JAVA_DIR=${TRAVIS_BUILD_DIR}/java +export ARROW_TEST_DATA=/arrow/testing/data -pushd $JAVA_DIR +export ARROW_JAVA_RUN_TESTS=1 -export MAVEN_OPTS="$MAVEN_OPTS -Dorg.slf4j.simpleLogger.defaultLogLevel=warn" -$TRAVIS_MVN -B site +export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 +export ARROW_JAVADOC=1 +bash $SOURCE_DIR/docker_build_java.sh -popd +export JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64 +export ARROW_JAVADOC=0 +bash $SOURCE_DIR/docker_build_java.sh diff --git a/docker-compose.yml b/docker-compose.yml index 2fa5ab47438..94171483c9a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -292,6 +292,18 @@ services: - .:/arrow:ro # ensures that docker won't contaminate the host directory - maven-cache:/root/.m2:delegated + java-all-jdks: + # Usage: + # docker-compose build java-all-jdks + # docker-compose run java-all-jdks + image: ursalab/arrow-ci-java-all-jdks:latest + build: + context: . + dockerfile: java/Dockerfile.all-jdks + volumes: + - .:/arrow:ro # ensures that docker won't contaminate the host directory + - maven-cache:/root/.m2:delegated + js: image: arrow:js build: diff --git a/java/Dockerfile.all-jdks b/java/Dockerfile.all-jdks new file mode 100644 index 00000000000..bf4e2afa227 --- /dev/null +++ b/java/Dockerfile.all-jdks @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +FROM ubuntu:18.04 + +# install build essentials +RUN export DEBIAN_FRONTEND=noninteractive && \ + apt-get update -y -q && \ + apt-get install -y -q --no-install-recommends \ + wget \ + software-properties-common \ + ca-certificates \ + maven \ + rsync \ + tzdata \ + openjdk-8-jdk \ + openjdk-11-jdk && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + +# Test all supported JDKs +CMD ["arrow/ci/docker_java_test_all.sh"] From 74b9294bf63ff49818f2d6a72877139a1a540f60 Mon Sep 17 00:00:00 2001 From: Rok Date: Tue, 2 Jul 2019 10:21:41 +0200 Subject: [PATCH 26/52] ARROW-4453: [Python] Cython wrappers for SparseTensor Creating cython wrappers for SparseTensor. This is to resolve [ARROW-4453](https://issues.apache.org/jira/browse/ARROW-4453). Author: Rok Author: Antoine Pitrou Closes #4446 from rok/ARROW-4453 and squashes the following commits: db5d620fe Typo. 9e0363afe Polish code c31b8eb32 Enabling SparseTensor.Equals checks. 654002afe Partial review feedback implementation. e89edc620 Refactoring to_numpy methods. 3fcc1929e Add equality methods. 4a30487fc Set base object in to_numpy methods. 4eeae02d8 Cython wrapper for SparseTensor. 
--- cpp/src/arrow/compare.cc | 3 +- cpp/src/arrow/python/numpy_convert.cc | 173 ++++++++-- cpp/src/arrow/python/numpy_convert.h | 29 ++ cpp/src/arrow/python/pyarrow.cc | 38 +++ cpp/src/arrow/python/pyarrow.h | 14 + cpp/src/arrow/python/pyarrow_api.h | 18 + cpp/src/arrow/python/pyarrow_lib.h | 4 + cpp/src/arrow/python/serialize.cc | 2 +- cpp/src/arrow/sparse_tensor-test.cc | 39 +++ docs/source/python/extending.rst | 42 +++ python/pyarrow/__init__.pxd | 9 +- python/pyarrow/__init__.py | 1 + python/pyarrow/array.pxi | 98 ------ python/pyarrow/includes/libarrow.pxd | 60 ++++ python/pyarrow/lib.pxd | 30 ++ python/pyarrow/lib.pyx | 3 + python/pyarrow/public-api.pxi | 50 ++- python/pyarrow/tensor.pxi | 367 +++++++++++++++++++++ python/pyarrow/tests/test_sparse_tensor.py | 221 +++++++++++++ python/pyarrow/tests/test_tensor.py | 46 +-- 20 files changed, 1101 insertions(+), 146 deletions(-) create mode 100644 python/pyarrow/tensor.pxi create mode 100644 python/pyarrow/tests/test_sparse_tensor.py diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 12991b94aeb..4ae5d897917 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -1026,9 +1026,8 @@ struct SparseTensorEqualsImpl { const uint8_t* left_data = left.data()->data(); const uint8_t* right_data = right.data()->data(); - return memcmp(left_data, right_data, - static_cast(byte_width * left.non_zero_length())); + static_cast(byte_width * left.non_zero_length())) == 0; } }; diff --git a/cpp/src/arrow/python/numpy_convert.cc b/cpp/src/arrow/python/numpy_convert.cc index f7068b353be..515864ae287 100644 --- a/cpp/src/arrow/python/numpy_convert.cc +++ b/cpp/src/arrow/python/numpy_convert.cc @@ -25,8 +25,10 @@ #include #include "arrow/buffer.h" +#include "arrow/sparse_tensor.h" #include "arrow/tensor.h" #include "arrow/type.h" +#include "arrow/util/logging.h" #include "arrow/python/common.h" #include "arrow/python/pyarrow.h" @@ -186,7 +188,9 @@ Status NumPyDtypeToArrow(PyArray_Descr* descr, 
std::shared_ptr* out) { #undef TO_ARROW_TYPE_CASE -Status NdarrayToTensor(MemoryPool* pool, PyObject* ao, std::shared_ptr* out) { +Status NdarrayToTensor(MemoryPool* pool, PyObject* ao, + const std::vector& dim_names, + std::shared_ptr* out) { if (!PyArray_Check(ao)) { return Status::TypeError("Did not pass ndarray object"); } @@ -197,35 +201,29 @@ Status NdarrayToTensor(MemoryPool* pool, PyObject* ao, std::shared_ptr* int ndim = PyArray_NDIM(ndarray); - // This is also holding the GIL, so don't already draw it. std::shared_ptr data = std::make_shared(ao); std::vector shape(ndim); std::vector strides(ndim); - { - PyAcquireGIL lock; - npy_intp* array_strides = PyArray_STRIDES(ndarray); - npy_intp* array_shape = PyArray_SHAPE(ndarray); - for (int i = 0; i < ndim; ++i) { - if (array_strides[i] < 0) { - return Status::Invalid("Negative ndarray strides not supported"); - } - shape[i] = array_shape[i]; - strides[i] = array_strides[i]; + npy_intp* array_strides = PyArray_STRIDES(ndarray); + npy_intp* array_shape = PyArray_SHAPE(ndarray); + for (int i = 0; i < ndim; ++i) { + if (array_strides[i] < 0) { + return Status::Invalid("Negative ndarray strides not supported"); } - - std::shared_ptr type; - RETURN_NOT_OK( - GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray)), &type)); - *out = std::make_shared(type, data, shape, strides); - return Status::OK(); + shape[i] = array_shape[i]; + strides[i] = array_strides[i]; } + + std::shared_ptr type; + RETURN_NOT_OK( + GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray)), &type)); + *out = std::make_shared(type, data, shape, strides, dim_names); + return Status::OK(); } Status TensorToNdarray(const std::shared_ptr& tensor, PyObject* base, PyObject** out) { - PyAcquireGIL lock; - int type_num; RETURN_NOT_OK(GetNumPyType(*tensor->type(), &type_num)); PyArray_Descr* dtype = PyArray_DescrNewFromType(type_num); @@ -274,5 +272,140 @@ Status TensorToNdarray(const std::shared_ptr& tensor, PyObject* base, return Status::OK(); } +// Wrap 
the dense data of a sparse tensor in a ndarray +static Status SparseTensorDataToNdarray(const SparseTensor& sparse_tensor, + std::vector data_shape, PyObject* base, + PyObject** out_data) { + int type_num_data; + RETURN_NOT_OK(GetNumPyType(*sparse_tensor.type(), &type_num_data)); + PyArray_Descr* dtype_data = PyArray_DescrNewFromType(type_num_data); + RETURN_IF_PYERROR(); + + const void* immutable_data = sparse_tensor.data()->data(); + // Remove const =( + void* mutable_data = const_cast(immutable_data); + int array_flags = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS; + if (sparse_tensor.is_mutable()) { + array_flags |= NPY_ARRAY_WRITEABLE; + } + + *out_data = PyArray_NewFromDescr(&PyArray_Type, dtype_data, + static_cast(data_shape.size()), data_shape.data(), + nullptr, mutable_data, array_flags, nullptr); + RETURN_IF_PYERROR() + Py_XINCREF(base); + PyArray_SetBaseObject(reinterpret_cast(*out_data), base); + return Status::OK(); +} + +Status SparseTensorCOOToNdarray(const std::shared_ptr& sparse_tensor, + PyObject* base, PyObject** out_data, + PyObject** out_coords) { + const auto& sparse_index = arrow::internal::checked_cast( + *sparse_tensor->sparse_index()); + + // Wrap tensor data + OwnedRef result_data; + RETURN_NOT_OK(SparseTensorDataToNdarray( + *sparse_tensor, {sparse_index.non_zero_length(), 1}, base, result_data.ref())); + + // Wrap indices + PyObject* result_coords; + RETURN_NOT_OK(TensorToNdarray(sparse_index.indices(), base, &result_coords)); + + *out_data = result_data.detach(); + *out_coords = result_coords; + return Status::OK(); +} + +Status SparseTensorCSRToNdarray(const std::shared_ptr& sparse_tensor, + PyObject* base, PyObject** out_data, + PyObject** out_indptr, PyObject** out_indices) { + const auto& sparse_index = arrow::internal::checked_cast( + *sparse_tensor->sparse_index()); + + // Wrap tensor data + OwnedRef result_data; + RETURN_NOT_OK(SparseTensorDataToNdarray( + *sparse_tensor, {sparse_index.non_zero_length(), 1}, base, 
result_data.ref())); + + // Wrap indices + OwnedRef result_indptr; + OwnedRef result_indices; + RETURN_NOT_OK(TensorToNdarray(sparse_index.indptr(), base, result_indptr.ref())); + RETURN_NOT_OK(TensorToNdarray(sparse_index.indices(), base, result_indices.ref())); + + *out_data = result_data.detach(); + *out_indptr = result_indptr.detach(); + *out_indices = result_indices.detach(); + return Status::OK(); +} + +Status NdarraysToSparseTensorCOO(MemoryPool* pool, PyObject* data_ao, PyObject* coords_ao, + const std::vector& shape, + const std::vector& dim_names, + std::shared_ptr* out) { + if (!PyArray_Check(data_ao) || !PyArray_Check(coords_ao)) { + return Status::TypeError("Did not pass ndarray object"); + } + + PyArrayObject* ndarray_data = reinterpret_cast(data_ao); + std::shared_ptr data = std::make_shared(data_ao); + std::shared_ptr type_data; + RETURN_NOT_OK(GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray_data)), + &type_data)); + + std::shared_ptr coords; + RETURN_NOT_OK(NdarrayToTensor(pool, coords_ao, {}, &coords)); + ARROW_CHECK_EQ(coords->type_id(), Type::INT64); // Should be ensured by caller + + std::shared_ptr sparse_index = std::make_shared( + std::static_pointer_cast>(coords)); + *out = std::make_shared>(sparse_index, type_data, data, + shape, dim_names); + return Status::OK(); +} + +Status NdarraysToSparseTensorCSR(MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao, + PyObject* indices_ao, const std::vector& shape, + const std::vector& dim_names, + std::shared_ptr* out) { + if (!PyArray_Check(data_ao) || !PyArray_Check(indptr_ao) || + !PyArray_Check(indices_ao)) { + return Status::TypeError("Did not pass ndarray object"); + } + + PyArrayObject* ndarray_data = reinterpret_cast(data_ao); + std::shared_ptr data = std::make_shared(data_ao); + std::shared_ptr type_data; + RETURN_NOT_OK(GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray_data)), + &type_data)); + + std::shared_ptr indptr, indices; + RETURN_NOT_OK(NdarrayToTensor(pool, indptr_ao, 
{}, &indptr)); + RETURN_NOT_OK(NdarrayToTensor(pool, indices_ao, {}, &indices)); + ARROW_CHECK_EQ(indptr->type_id(), Type::INT64); // Should be ensured by caller + ARROW_CHECK_EQ(indices->type_id(), Type::INT64); // Should be ensured by caller + + auto sparse_index = std::make_shared( + std::static_pointer_cast>(indptr), + std::static_pointer_cast>(indices)); + *out = std::make_shared>(sparse_index, type_data, data, + shape, dim_names); + return Status::OK(); +} + +Status TensorToSparseTensorCOO(const std::shared_ptr& tensor, + std::shared_ptr* out) { + *out = std::make_shared(*tensor); + return Status::OK(); +} + +Status TensorToSparseTensorCSR(const std::shared_ptr& tensor, + std::shared_ptr* out) { + *out = std::make_shared(*tensor); + return Status::OK(); +} + } // namespace py } // namespace arrow diff --git a/cpp/src/arrow/python/numpy_convert.h b/cpp/src/arrow/python/numpy_convert.h index dce5fe522d6..5fa1326f52b 100644 --- a/cpp/src/arrow/python/numpy_convert.h +++ b/cpp/src/arrow/python/numpy_convert.h @@ -25,9 +25,11 @@ #include #include +#include #include "arrow/buffer.h" #include "arrow/python/visibility.h" +#include "arrow/sparse_tensor.h" namespace arrow { @@ -63,11 +65,38 @@ Status GetTensorType(PyObject* dtype, std::shared_ptr* out); Status GetNumPyType(const DataType& type, int* type_num); ARROW_PYTHON_EXPORT Status NdarrayToTensor(MemoryPool* pool, PyObject* ao, + const std::vector& dim_names, std::shared_ptr* out); ARROW_PYTHON_EXPORT Status TensorToNdarray(const std::shared_ptr& tensor, PyObject* base, PyObject** out); +ARROW_PYTHON_EXPORT Status +SparseTensorCOOToNdarray(const std::shared_ptr& sparse_tensor, + PyObject* base, PyObject** out_data, PyObject** out_coords); + +ARROW_PYTHON_EXPORT Status SparseTensorCSRToNdarray( + const std::shared_ptr& sparse_tensor, PyObject* base, + PyObject** out_data, PyObject** out_indptr, PyObject** out_indices); + +ARROW_PYTHON_EXPORT Status NdarraysToSparseTensorCOO( + MemoryPool* pool, PyObject* data_ao, 
PyObject* coords_ao, + const std::vector& shape, const std::vector& dim_names, + std::shared_ptr* out); + +ARROW_PYTHON_EXPORT Status NdarraysToSparseTensorCSR( + MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao, PyObject* indices_ao, + const std::vector& shape, const std::vector& dim_names, + std::shared_ptr* out); + +ARROW_PYTHON_EXPORT Status +TensorToSparseTensorCOO(const std::shared_ptr& tensor, + std::shared_ptr* csparse_tensor); + +ARROW_PYTHON_EXPORT Status +TensorToSparseTensorCSR(const std::shared_ptr& tensor, + std::shared_ptr* csparse_tensor); + } // namespace py } // namespace arrow diff --git a/cpp/src/arrow/python/pyarrow.cc b/cpp/src/arrow/python/pyarrow.cc index 1cedc549cfa..e037318bce2 100644 --- a/cpp/src/arrow/python/pyarrow.cc +++ b/cpp/src/arrow/python/pyarrow.cc @@ -123,6 +123,44 @@ PyObject* wrap_tensor(const std::shared_ptr& tensor) { return ::pyarrow_wrap_tensor(tensor); } +bool is_sparse_tensor_csr(PyObject* sparse_tensor) { + return ::pyarrow_is_sparse_tensor_csr(sparse_tensor) != 0; +} + +Status unwrap_sparse_tensor_csr(PyObject* sparse_tensor, + std::shared_ptr* out) { + *out = ::pyarrow_unwrap_sparse_tensor_csr(sparse_tensor); + if (*out) { + return Status::OK(); + } else { + return Status::Invalid( + "Could not unwrap SparseTensorCSR from the passed Python object."); + } +} + +PyObject* wrap_sparse_tensor_csr(const std::shared_ptr& sparse_tensor) { + return ::pyarrow_wrap_sparse_tensor_csr(sparse_tensor); +} + +bool is_sparse_tensor_coo(PyObject* sparse_tensor) { + return ::pyarrow_is_sparse_tensor_coo(sparse_tensor) != 0; +} + +Status unwrap_sparse_tensor_coo(PyObject* sparse_tensor, + std::shared_ptr* out) { + *out = ::pyarrow_unwrap_sparse_tensor_coo(sparse_tensor); + if (*out) { + return Status::OK(); + } else { + return Status::Invalid( + "Could not unwrap SparseTensorCOO from the passed Python object."); + } +} + +PyObject* wrap_sparse_tensor_coo(const std::shared_ptr& sparse_tensor) { + return 
::pyarrow_wrap_sparse_tensor_coo(sparse_tensor); +} + bool is_column(PyObject* column) { return ::pyarrow_is_column(column) != 0; } Status unwrap_column(PyObject* column, std::shared_ptr* out) { diff --git a/cpp/src/arrow/python/pyarrow.h b/cpp/src/arrow/python/pyarrow.h index ff5bf8f01dd..b4834f79f78 100644 --- a/cpp/src/arrow/python/pyarrow.h +++ b/cpp/src/arrow/python/pyarrow.h @@ -24,6 +24,8 @@ #include "arrow/python/visibility.h" +#include "arrow/sparse_tensor.h" + namespace arrow { class Array; @@ -67,6 +69,18 @@ ARROW_PYTHON_EXPORT bool is_tensor(PyObject* tensor); ARROW_PYTHON_EXPORT Status unwrap_tensor(PyObject* tensor, std::shared_ptr* out); ARROW_PYTHON_EXPORT PyObject* wrap_tensor(const std::shared_ptr& tensor); +ARROW_PYTHON_EXPORT bool is_sparse_tensor_coo(PyObject* sparse_tensor); +ARROW_PYTHON_EXPORT Status +unwrap_sparse_tensor_coo(PyObject* sparse_tensor, std::shared_ptr* out); +ARROW_PYTHON_EXPORT PyObject* wrap_sparse_tensor_coo( + const std::shared_ptr& sparse_tensor); + +ARROW_PYTHON_EXPORT bool is_sparse_tensor_csr(PyObject* sparse_tensor); +ARROW_PYTHON_EXPORT Status +unwrap_sparse_tensor_csr(PyObject* sparse_tensor, std::shared_ptr* out); +ARROW_PYTHON_EXPORT PyObject* wrap_sparse_tensor_csr( + const std::shared_ptr& sparse_tensor); + ARROW_PYTHON_EXPORT bool is_column(PyObject* column); ARROW_PYTHON_EXPORT Status unwrap_column(PyObject* column, std::shared_ptr* out); ARROW_PYTHON_EXPORT PyObject* wrap_column(const std::shared_ptr& column); diff --git a/cpp/src/arrow/python/pyarrow_api.h b/cpp/src/arrow/python/pyarrow_api.h index b76e9614a8a..2d8f71c8c5a 100644 --- a/cpp/src/arrow/python/pyarrow_api.h +++ b/cpp/src/arrow/python/pyarrow_api.h @@ -50,6 +50,10 @@ static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table)(std::shared_ptr #define pyarrow_wrap_table __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor)(std::shared_ptr< arrow::Tensor> const &) = 0; #define 
pyarrow_wrap_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_tensor_csr)(std::shared_ptr< arrow::SparseTensorCSR> const &) = 0; +#define pyarrow_wrap_sparse_tensor_csr __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_tensor_csr +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_tensor_coo)(std::shared_ptr< arrow::SparseTensorCOO> const &) = 0; +#define pyarrow_wrap_sparse_tensor_coo __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_tensor_coo static std::shared_ptr< arrow::Array> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array)(PyObject *) = 0; #define pyarrow_unwrap_array __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array static std::shared_ptr< arrow::RecordBatch> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_batch)(PyObject *) = 0; @@ -68,6 +72,10 @@ static std::shared_ptr< arrow::Table> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwra #define pyarrow_unwrap_table __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_table static std::shared_ptr< arrow::Tensor> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor)(PyObject *) = 0; #define pyarrow_unwrap_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor +static std::shared_ptr< arrow::SparseTensorCSR> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_tensor_csr)(PyObject *) = 0; +#define pyarrow_unwrap_sparse_tensor_csr __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_tensor_csr +static std::shared_ptr< arrow::SparseTensorCOO> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_tensor_coo)(PyObject *) = 0; +#define pyarrow_unwrap_sparse_tensor_coo __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_tensor_coo static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status)(arrow::Status const &) = 0; #define pyarrow_internal_check_status __pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_buffer)(PyObject *) = 0; @@ -84,6 +92,10 @@ static PyObject 
*(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar)(std::shared_pt #define pyarrow_wrap_scalar __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor)(PyObject *) = 0; #define pyarrow_is_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_tensor_csr)(PyObject *) = 0; +#define pyarrow_is_sparse_tensor_csr __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_tensor_csr +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_tensor_coo)(PyObject *) = 0; +#define pyarrow_is_sparse_tensor_coo __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_tensor_coo static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_column)(PyObject *) = 0; #define pyarrow_is_column __pyx_api_f_7pyarrow_3lib_pyarrow_is_column static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_table)(PyObject *) = 0; @@ -167,6 +179,8 @@ static int import_pyarrow__lib(void) { if (__Pyx_ImportFunction(module, "pyarrow_wrap_schema", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_schema, "PyObject *(std::shared_ptr< arrow::Schema> const &)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_wrap_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table, "PyObject *(std::shared_ptr< arrow::Table> const &)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_wrap_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor, "PyObject *(std::shared_ptr< arrow::Tensor> const &)") < 0) goto bad; + if (__Pyx_ImportFunction(module, "pyarrow_wrap_sparse_tensor_csr", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_tensor_csr, "PyObject *(std::shared_ptr< arrow::SparseTensorCSR> const &)") < 0) goto bad; + if (__Pyx_ImportFunction(module, "pyarrow_wrap_sparse_tensor_coo", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_tensor_coo, "PyObject *(std::shared_ptr< arrow::SparseTensorCOO> const &)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_unwrap_array",
(void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array, "std::shared_ptr< arrow::Array> (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_unwrap_batch", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_batch, "std::shared_ptr< arrow::RecordBatch> (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_unwrap_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_buffer, "std::shared_ptr< arrow::Buffer> (PyObject *)") < 0) goto bad; @@ -176,6 +190,8 @@ static int import_pyarrow__lib(void) { if (__Pyx_ImportFunction(module, "pyarrow_unwrap_schema", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_schema, "std::shared_ptr< arrow::Schema> (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_unwrap_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_table, "std::shared_ptr< arrow::Table> (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_unwrap_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor, "std::shared_ptr< arrow::Tensor> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction(module, "pyarrow_unwrap_sparse_tensor_csr", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_tensor_csr, "std::shared_ptr< arrow::SparseTensorCSR> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction(module, "pyarrow_unwrap_sparse_tensor_coo", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_tensor_coo, "std::shared_ptr< arrow::SparseTensorCOO> (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_internal_check_status", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status, "int (arrow::Status const &)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_is_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_buffer, "int (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_is_data_type", (void 
(**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_data_type, "int (PyObject *)") < 0) goto bad; @@ -184,6 +200,8 @@ static int import_pyarrow__lib(void) { if (__Pyx_ImportFunction(module, "pyarrow_is_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_array, "int (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_wrap_scalar", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar, "PyObject *(std::shared_ptr< arrow::Scalar> const &)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_is_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction(module, "pyarrow_is_sparse_tensor_csr", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_tensor_csr, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction(module, "pyarrow_is_sparse_tensor_coo", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_tensor_coo, "int (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_is_column", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_column, "int (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_is_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_table, "int (PyObject *)") < 0) goto bad; if (__Pyx_ImportFunction(module, "pyarrow_is_batch", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_batch, "int (PyObject *)") < 0) goto bad; diff --git a/cpp/src/arrow/python/pyarrow_lib.h b/cpp/src/arrow/python/pyarrow_lib.h index 5f5fc4c6b6f..a4bc1039ee8 100644 --- a/cpp/src/arrow/python/pyarrow_lib.h +++ b/cpp/src/arrow/python/pyarrow_lib.h @@ -48,6 +48,8 @@ __PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_resizable_buffer(std __PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_schema(std::shared_ptr< arrow::Schema> const &); __PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_table(std::shared_ptr< arrow::Table> const &); __PYX_EXTERN_C PyObject 
*__pyx_f_7pyarrow_3lib_pyarrow_wrap_tensor(std::shared_ptr< arrow::Tensor> const &); +__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_sparse_tensor_coo(std::shared_ptr< arrow::SparseTensorCOO> const &); +__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_sparse_tensor_csr(std::shared_ptr< arrow::SparseTensorCSR> const &); __PYX_EXTERN_C std::shared_ptr< arrow::Array> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_array(PyObject *); __PYX_EXTERN_C std::shared_ptr< arrow::RecordBatch> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_batch(PyObject *); __PYX_EXTERN_C std::shared_ptr< arrow::Buffer> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_buffer(PyObject *); @@ -57,6 +59,8 @@ __PYX_EXTERN_C std::shared_ptr< arrow::Field> __pyx_f_7pyarrow_3lib_pyarrow_unw __PYX_EXTERN_C std::shared_ptr< arrow::Schema> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_schema(PyObject *); __PYX_EXTERN_C std::shared_ptr< arrow::Table> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_table(PyObject *); __PYX_EXTERN_C std::shared_ptr< arrow::Tensor> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_tensor(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::SparseTensorCOO> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_sparse_tensor_coo(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::SparseTensorCSR> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_sparse_tensor_csr(PyObject *); #endif /* !__PYX_HAVE_API__pyarrow__lib */ diff --git a/cpp/src/arrow/python/serialize.cc b/cpp/src/arrow/python/serialize.cc index 8ff0e01480f..d93e3954e41 100644 --- a/cpp/src/arrow/python/serialize.cc +++ b/cpp/src/arrow/python/serialize.cc @@ -515,7 +515,7 @@ Status AppendArray(PyObject* context, PyArrayObject* array, SequenceBuilder* bui builder->AppendNdarray(static_cast(blobs_out->ndarrays.size()))); std::shared_ptr tensor; RETURN_NOT_OK(NdarrayToTensor(default_memory_pool(), - reinterpret_cast(array), &tensor)); + reinterpret_cast(array), {}, &tensor)); blobs_out->ndarrays.push_back(tensor); } break; default: { diff --git a/cpp/src/arrow/sparse_tensor-test.cc 
b/cpp/src/arrow/sparse_tensor-test.cc index daff0194fe5..69ec4ca5c60 100644 --- a/cpp/src/arrow/sparse_tensor-test.cc +++ b/cpp/src/arrow/sparse_tensor-test.cc @@ -182,6 +182,25 @@ TEST(TestSparseCOOTensor, CreationFromNonContiguousTensor) { AssertCOOIndex(sidx, 11, {1, 2, 3}); } +TEST(TestSparseCOOTensor, TensorEquality) { + std::vector shape = {2, 3, 4}; + std::vector values1 = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0, + 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16}; + std::vector values2 = {0, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0, + 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16}; + std::shared_ptr buffer1 = Buffer::Wrap(values1); + std::shared_ptr buffer2 = Buffer::Wrap(values2); + NumericTensor tensor1(buffer1, shape); + NumericTensor tensor2(buffer1, shape); + NumericTensor tensor3(buffer2, shape); + SparseTensorImpl st1(tensor1); + SparseTensorImpl st2(tensor2); + SparseTensorImpl st3(tensor3); + + ASSERT_TRUE(st1.Equals(st2)); + ASSERT_TRUE(!st1.Equals(st3)); +} + TEST(TestSparseCSRMatrix, CreationFromNumericTensor2D) { std::vector shape = {6, 4}; std::vector values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0, @@ -269,4 +288,24 @@ TEST(TestSparseCSRMatrix, CreationFromNonContiguousTensor) { ASSERT_EQ(std::vector({0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3}), indices_values); } +TEST(TestSparseCSRMatrix, TensorEquality) { + std::vector shape = {6, 4}; + std::vector values1 = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0, + 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16}; + std::vector values2 = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + std::shared_ptr buffer1 = Buffer::Wrap(values1); + std::shared_ptr buffer2 = Buffer::Wrap(values2); + NumericTensor tensor1(buffer1, shape); + NumericTensor tensor2(buffer1, shape); + NumericTensor tensor3(buffer2, shape); + SparseTensorImpl st1(tensor1); + SparseTensorImpl st2(tensor2); + SparseTensorImpl st3(tensor3); + + ASSERT_TRUE(st1.Equals(st2)); + ASSERT_TRUE(!st1.Equals(st3)); +} + } // namespace arrow diff --git 
a/docs/source/python/extending.rst b/docs/source/python/extending.rst index 6b5c9ce1902..f15b1bedbac 100644 --- a/docs/source/python/extending.rst +++ b/docs/source/python/extending.rst @@ -116,6 +116,16 @@ C++ objects. Return whether *obj* wraps an Arrow C++ :class:`Tensor` pointer; in other words, whether *obj* is a :py:class:`pyarrow.Tensor` instance. +.. function:: bool is_sparse_tensor_coo(PyObject* obj) + + Return whether *obj* wraps an Arrow C++ :class:`SparseTensorCOO` pointer; + in other words, whether *obj* is a :py:class:`pyarrow.SparseTensorCOO` instance. + +.. function:: bool is_sparse_tensor_csr(PyObject* obj) + + Return whether *obj* wraps an Arrow C++ :class:`SparseTensorCSR` pointer; + in other words, whether *obj* is a :py:class:`pyarrow.SparseTensorCSR` instance. + The following functions expect a pyarrow object, unwrap the underlying Arrow C++ API pointer, and put it in the *out* parameter. The returned :class:`Status` object must be inspected first to know whether any error @@ -157,6 +167,14 @@ occurred. If successful, *out* is guaranteed to be non-NULL. Unwrap the Arrow C++ :class:`Tensor` pointer from *obj* and put it in *out*. +.. function:: Status unwrap_sparse_tensor_coo(PyObject* obj, std::shared_ptr* out) + + Unwrap the Arrow C++ :class:`SparseTensorCOO` pointer from *obj* and put it in *out*. + +.. function:: Status unwrap_sparse_tensor_csr(PyObject* obj, std::shared_ptr* out) + + Unwrap the Arrow C++ :class:`SparseTensorCSR` pointer from *obj* and put it in *out*. + The following functions take an Arrow C++ API pointer and wrap it in a pyarray object of the corresponding type. A new reference is returned. On error, NULL is returned and a Python exception is set. @@ -197,6 +215,14 @@ On error, NULL is returned and a Python exception is set. Wrap the Arrow C++ *tensor* in a :py:class:`pyarrow.Tensor` instance. +.. 
function:: PyObject* wrap_sparse_tensor_coo(const std::shared_ptr& sparse_tensor) + + Wrap the Arrow C++ *COO sparse tensor* in a :py:class:`pyarrow.SparseTensorCOO` instance. + +.. function:: PyObject* wrap_sparse_tensor_csr(const std::shared_ptr& sparse_tensor) + + Wrap the Arrow C++ *CSR sparse tensor* in a :py:class:`pyarrow.SparseTensorCSR` instance. + Cython API ---------- @@ -257,6 +283,14 @@ an exception) if the input is not of the right type. Unwrap the Arrow C++ :cpp:class:`Tensor` pointer from *obj*. +.. function:: pyarrow_unwrap_sparse_tensor_coo(obj) -> shared_ptr[CSparseTensorCOO] + + Unwrap the Arrow C++ :cpp:class:`SparseTensorCOO` pointer from *obj*. + +.. function:: pyarrow_unwrap_sparse_tensor_csr(obj) -> shared_ptr[CSparseTensorCSR] + + Unwrap the Arrow C++ :cpp:class:`SparseTensorCSR` pointer from *obj*. + The following functions take a Arrow C++ API pointer and wrap it in a pyarray object of the corresponding type. An exception is raised on error. @@ -300,6 +334,14 @@ pyarray object of the corresponding type. An exception is raised on error. Wrap the Arrow C++ *tensor* in a Python :class:`pyarrow.Tensor` instance. +.. function:: pyarrow_wrap_sparse_tensor_coo(sp_array: const shared_ptr[CSparseTensorCOO]& sparse_tensor) -> object + + Wrap the Arrow C++ *COO sparse tensor* in a Python :class:`pyarrow.SparseTensorCOO` instance. + +.. function:: pyarrow_wrap_sparse_tensor_csr(sp_array: const shared_ptr[CSparseTensorCSR]& sparse_tensor) -> object + + Wrap the Arrow C++ *CSR sparse tensor* in a Python :class:`pyarrow.SparseTensorCSR` instance. 
+ Example ~~~~~~~ diff --git a/python/pyarrow/__init__.pxd b/python/pyarrow/__init__.pxd index 95cea5ca4fc..432880556cc 100644 --- a/python/pyarrow/__init__.pxd +++ b/python/pyarrow/__init__.pxd @@ -20,8 +20,9 @@ from __future__ import absolute_import from libcpp.memory cimport shared_ptr from pyarrow.includes.libarrow cimport (CArray, CBuffer, CColumn, CDataType, CField, CRecordBatch, CSchema, - CTable, CTensor) - + CTable, CTensor, + CSparseTensorCSR, CSparseTensorCOO) +from pyarrow.compat import frombytes cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": cdef int import_pyarrow() except -1 @@ -31,6 +32,10 @@ cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": cdef object wrap_schema(const shared_ptr[CSchema]& schema) cdef object wrap_array(const shared_ptr[CArray]& sp_array) cdef object wrap_tensor(const shared_ptr[CTensor]& sp_tensor) + cdef object wrap_sparse_tensor_coo( + const shared_ptr[CSparseTensorCOO]& sp_sparse_tensor) + cdef object wrap_sparse_tensor_csr( + const shared_ptr[CSparseTensorCSR]& sp_sparse_tensor) cdef object wrap_column(const shared_ptr[CColumn]& ccolumn) cdef object wrap_table(const shared_ptr[CTable]& ctable) cdef object wrap_batch(const shared_ptr[CRecordBatch]& cbatch) diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 487065c2892..bbbd91a9508 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -66,6 +66,7 @@ def parse_git(root, **kwargs): schema, Array, Tensor, array, chunked_array, column, table, + SparseTensorCSR, SparseTensorCOO, infer_type, from_numpy_dtype, NullArray, NumericArray, IntegerArray, FloatingPointArray, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 5ae178d8953..15905a18507 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -870,104 +870,6 @@ cdef class Array(_PandasConvertible): return res -cdef class Tensor: - """ - A n-dimensional array a.k.a Tensor. 
- """ - - def __init__(self): - raise TypeError("Do not call Tensor's constructor directly, use one " - "of the `pyarrow.Tensor.from_*` functions instead.") - - cdef void init(self, const shared_ptr[CTensor]& sp_tensor): - self.sp_tensor = sp_tensor - self.tp = sp_tensor.get() - self.type = pyarrow_wrap_data_type(self.tp.type()) - - def __repr__(self): - return """ -type: {0.type} -shape: {0.shape} -strides: {0.strides}""".format(self) - - @staticmethod - def from_numpy(obj): - cdef shared_ptr[CTensor] ctensor - with nogil: - check_status(NdarrayToTensor(c_default_memory_pool(), obj, - &ctensor)) - return pyarrow_wrap_tensor(ctensor) - - def to_numpy(self): - """ - Convert arrow::Tensor to numpy.ndarray with zero copy - """ - cdef PyObject* out - - with nogil: - check_status(TensorToNdarray(self.sp_tensor, self, &out)) - return PyObject_to_object(out) - - def equals(self, Tensor other): - """ - Return true if the tensors contains exactly equal data - """ - return self.tp.Equals(deref(other.tp)) - - def __eq__(self, other): - if isinstance(other, Tensor): - return self.equals(other) - else: - return NotImplemented - - @property - def is_mutable(self): - return self.tp.is_mutable() - - @property - def is_contiguous(self): - return self.tp.is_contiguous() - - @property - def ndim(self): - return self.tp.ndim() - - @property - def size(self): - return self.tp.size() - - @property - def shape(self): - # Cython knows how to convert a vector[T] to a Python list - return tuple(self.tp.shape()) - - @property - def strides(self): - return tuple(self.tp.strides()) - - def __getbuffer__(self, cp.Py_buffer* buffer, int flags): - buffer.buf = self.tp.data().get().data() - pep3118_format = self.type.pep3118_format - if pep3118_format is None: - raise NotImplementedError("type %s not supported for buffer " - "protocol" % (self.type,)) - buffer.format = pep3118_format - buffer.itemsize = self.type.bit_width // 8 - buffer.internal = NULL - buffer.len = self.tp.size() * 
buffer.itemsize - buffer.ndim = self.tp.ndim() - buffer.obj = self - if self.tp.is_mutable(): - buffer.readonly = 0 - else: - buffer.readonly = 1 - # NOTE: This assumes Py_ssize_t == int64_t, and that the shape - # and strides arrays lifetime is tied to the tensor's - buffer.shape = &self.tp.shape()[0] - buffer.strides = &self.tp.strides()[0] - buffer.suboffsets = NULL - - cdef wrap_array_output(PyObject* output): cdef object obj = PyObject_to_object(output) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 8798834b5fd..93a75945ce3 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -593,6 +593,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: int64_t size() int ndim() + const vector[c_string]& dim_names() const c_string& dim_name(int i) c_bool is_mutable() @@ -600,6 +601,38 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: Type type_id() c_bool Equals(const CTensor& other) + cdef cppclass CSparseTensorCOO" arrow::SparseTensorCOO": + shared_ptr[CDataType] type() + shared_ptr[CBuffer] data() + + const vector[int64_t]& shape() + int64_t size() + int64_t non_zero_length() + + int ndim() + const vector[c_string]& dim_names() + const c_string& dim_name(int i) + + c_bool is_mutable() + Type type_id() + c_bool Equals(const CSparseTensorCOO& other) + + cdef cppclass CSparseTensorCSR" arrow::SparseTensorCSR": + shared_ptr[CDataType] type() + shared_ptr[CBuffer] data() + + const vector[int64_t]& shape() + int64_t size() + int64_t non_zero_length() + + int ndim() + const vector[c_string]& dim_names() + const c_string& dim_name(int i) + + c_bool is_mutable() + Type type_id() + c_bool Equals(const CSparseTensorCSR& other) + cdef cppclass CScalar" arrow::Scalar": shared_ptr[CDataType] type @@ -1202,11 +1235,38 @@ cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil: shared_ptr[CChunkedArray]* out) CStatus NdarrayToTensor(CMemoryPool* pool, object ao, + const 
vector[c_string]& dim_names, shared_ptr[CTensor]* out) CStatus TensorToNdarray(const shared_ptr[CTensor]& tensor, object base, PyObject** out) + CStatus SparseTensorCOOToNdarray( + const shared_ptr[CSparseTensorCOO]& sparse_tensor, object base, + PyObject** out_data, PyObject** out_coords) + + CStatus SparseTensorCSRToNdarray( + const shared_ptr[CSparseTensorCSR]& sparse_tensor, object base, + PyObject** out_data, PyObject** out_indptr, PyObject** out_indices) + + CStatus NdarraysToSparseTensorCOO(CMemoryPool* pool, object data_ao, + object coords_ao, + const vector[int64_t]& shape, + const vector[c_string]& dim_names, + shared_ptr[CSparseTensorCOO]* out) + + CStatus NdarraysToSparseTensorCSR(CMemoryPool* pool, object data_ao, + object indptr_ao, object indices_ao, + const vector[int64_t]& shape, + const vector[c_string]& dim_names, + shared_ptr[CSparseTensorCSR]* out) + + CStatus TensorToSparseTensorCOO(shared_ptr[CTensor], + shared_ptr[CSparseTensorCOO]* out) + + CStatus TensorToSparseTensorCSR(shared_ptr[CTensor], + shared_ptr[CSparseTensorCSR]* out) + CStatus ConvertArrayToPandas(const PandasOptions& options, const shared_ptr[CArray]& arr, object py_ref, PyObject** out) diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 79ab9478b16..898c70a4bf7 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -231,6 +231,28 @@ cdef class Tensor: cdef void init(self, const shared_ptr[CTensor]& sp_tensor) +cdef class SparseTensorCSR: + cdef: + shared_ptr[CSparseTensorCSR] sp_sparse_tensor + CSparseTensorCSR* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseTensorCSR]& sp_sparse_tensor) + + +cdef class SparseTensorCOO: + cdef: + shared_ptr[CSparseTensorCOO] sp_sparse_tensor + CSparseTensorCOO* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseTensorCOO]& sp_sparse_tensor) + + cdef class NullArray(Array): pass @@ -452,6 +474,10 @@ cdef public object 
pyarrow_wrap_resizable_buffer( cdef public object pyarrow_wrap_schema(const shared_ptr[CSchema]& type) cdef public object pyarrow_wrap_table(const shared_ptr[CTable]& ctable) cdef public object pyarrow_wrap_tensor(const shared_ptr[CTensor]& sp_tensor) +cdef public object pyarrow_wrap_sparse_tensor_coo( + const shared_ptr[CSparseTensorCOO]& sp_sparse_tensor) +cdef public object pyarrow_wrap_sparse_tensor_csr( + const shared_ptr[CSparseTensorCSR]& sp_sparse_tensor) cdef public shared_ptr[CArray] pyarrow_unwrap_array(object array) cdef public shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch) @@ -462,3 +488,7 @@ cdef public shared_ptr[CField] pyarrow_unwrap_field(object field) cdef public shared_ptr[CSchema] pyarrow_unwrap_schema(object schema) cdef public shared_ptr[CTable] pyarrow_unwrap_table(object table) cdef public shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor) +cdef public shared_ptr[CSparseTensorCOO] pyarrow_unwrap_sparse_tensor_coo( + object sparse_tensor) +cdef public shared_ptr[CSparseTensorCSR] pyarrow_unwrap_sparse_tensor_csr( + object sparse_tensor) diff --git a/python/pyarrow/lib.pyx b/python/pyarrow/lib.pyx index 783e2b2731a..2da5a8301bc 100644 --- a/python/pyarrow/lib.pyx +++ b/python/pyarrow/lib.pyx @@ -121,6 +121,9 @@ include "builder.pxi" # Column, Table, Record Batch include "table.pxi" +# Tensors +include "tensor.pxi" + # File IO include "io.pxi" include "io-hdfs.pxi" diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 33bc8031804..05c07748f17 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -18,7 +18,8 @@ from libcpp.memory cimport shared_ptr from pyarrow.includes.libarrow cimport (CArray, CColumn, CDataType, CField, CRecordBatch, CSchema, - CTable, CTensor) + CTable, CTensor, + CSparseTensorCSR, CSparseTensorCOO) # You cannot assign something to a dereferenced pointer in Cython thus these # methods don't use Status to indicate a successful operation. 
@@ -225,6 +226,7 @@ cdef api object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar): scalar.init(sp_scalar) return scalar + cdef api bint pyarrow_is_tensor(object tensor): return isinstance(tensor, Tensor) @@ -248,6 +250,52 @@ cdef api object pyarrow_wrap_tensor( return tensor +cdef api bint pyarrow_is_sparse_tensor_coo(object sparse_tensor): + return isinstance(sparse_tensor, SparseTensorCOO) + +cdef api shared_ptr[CSparseTensorCOO] pyarrow_unwrap_sparse_tensor_coo( + object sparse_tensor): + cdef SparseTensorCOO sten + if pyarrow_is_sparse_tensor_coo(sparse_tensor): + sten = (sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseTensorCOO]() + +cdef api object pyarrow_wrap_sparse_tensor_coo( + const shared_ptr[CSparseTensorCOO]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseTensorCOO was NULL') + + cdef SparseTensorCOO sparse_tensor = SparseTensorCOO.__new__( + SparseTensorCOO) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_sparse_tensor_csr(object sparse_tensor): + return isinstance(sparse_tensor, SparseTensorCSR) + +cdef api shared_ptr[CSparseTensorCSR] pyarrow_unwrap_sparse_tensor_csr( + object sparse_tensor): + cdef SparseTensorCSR sten + if pyarrow_is_sparse_tensor_csr(sparse_tensor): + sten = (sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseTensorCSR]() + +cdef api object pyarrow_wrap_sparse_tensor_csr( + const shared_ptr[CSparseTensorCSR]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseTensorCSR was NULL') + + cdef SparseTensorCSR sparse_tensor = SparseTensorCSR.__new__( + SparseTensorCSR) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + cdef api bint pyarrow_is_column(object column): return isinstance(column, Column) diff --git a/python/pyarrow/tensor.pxi b/python/pyarrow/tensor.pxi new file mode 100644 index 00000000000..17554e61740 --- /dev/null +++ 
b/python/pyarrow/tensor.pxi @@ -0,0 +1,367 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +cdef class Tensor: + """ + A n-dimensional array a.k.a Tensor. + """ + + def __init__(self): + raise TypeError("Do not call Tensor's constructor directly, use one " + "of the `pyarrow.Tensor.from_*` functions instead.") + + cdef void init(self, const shared_ptr[CTensor]& sp_tensor): + self.sp_tensor = sp_tensor + self.tp = sp_tensor.get() + self.type = pyarrow_wrap_data_type(self.tp.type()) + + def __repr__(self): + return """ +type: {0.type} +shape: {0.shape} +strides: {0.strides}""".format(self) + + @staticmethod + def from_numpy(obj, dim_names=None): + cdef: + vector[c_string] c_dim_names + shared_ptr[CTensor] ctensor + + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + check_status(NdarrayToTensor(c_default_memory_pool(), obj, + c_dim_names, &ctensor)) + return pyarrow_wrap_tensor(ctensor) + + def to_numpy(self): + """ + Convert arrow::Tensor to numpy.ndarray with zero copy + """ + cdef PyObject* out + + check_status(TensorToNdarray(self.sp_tensor, self, &out)) + return PyObject_to_object(out) + + def equals(self, Tensor other): + """ + Return true if the tensors contains exactly equal data + 
""" + return self.tp.Equals(deref(other.tp)) + + def __eq__(self, other): + if isinstance(other, Tensor): + return self.equals(other) + else: + return NotImplemented + + def dim_name(self, i): + return frombytes(self.tp.dim_name(i)) + + @property + def dim_names(self): + return [frombytes(x) for x in tuple(self.tp.dim_names())] + + @property + def is_mutable(self): + return self.tp.is_mutable() + + @property + def is_contiguous(self): + return self.tp.is_contiguous() + + @property + def ndim(self): + return self.tp.ndim() + + @property + def size(self): + return self.tp.size() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.tp.shape()) + + @property + def strides(self): + return tuple(self.tp.strides()) + + def __getbuffer__(self, cp.Py_buffer* buffer, int flags): + buffer.buf = self.tp.data().get().data() + pep3118_format = self.type.pep3118_format + if pep3118_format is None: + raise NotImplementedError("type %s not supported for buffer " + "protocol" % (self.type,)) + buffer.format = pep3118_format + buffer.itemsize = self.type.bit_width // 8 + buffer.internal = NULL + buffer.len = self.tp.size() * buffer.itemsize + buffer.ndim = self.tp.ndim() + buffer.obj = self + if self.tp.is_mutable(): + buffer.readonly = 0 + else: + buffer.readonly = 1 + # NOTE: This assumes Py_ssize_t == int64_t, and that the shape + # and strides arrays lifetime is tied to the tensor's + buffer.shape = &self.tp.shape()[0] + buffer.strides = &self.tp.strides()[0] + buffer.suboffsets = NULL + + +cdef class SparseTensorCOO: + """ + A sparse COO tensor. 
+ """ + + def __init__(self): + raise TypeError("Do not call SparseTensorCOO's constructor directly, " + "use one of the `pyarrow.SparseTensorCOO.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseTensorCOO]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """ +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseTensorCOO + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, coords, shape, dim_names=None): + """ + Create arrow::SparseTensorCOO from numpy.ndarrays + """ + cdef shared_ptr[CSparseTensorCOO] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for SparseTensorCOO indices + coords = np.require(coords, dtype='i8', requirements='F') + if coords.ndim != 2: + raise ValueError("Expected 2-dimensional array for " + "SparseTensorCOO indices") + + check_status(NdarraysToSparseTensorCOO(c_default_memory_pool(), + data, coords, c_shape, c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_tensor_coo(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseTensorCOO + """ + cdef shared_ptr[CSparseTensorCOO] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseTensorCOO(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_tensor_coo(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseTensorCOO to numpy.ndarrays with zero copy + """ + cdef PyObject* out_data + cdef PyObject* out_coords + + 
check_status(SparseTensorCOOToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_coords)) + return PyObject_to_object(out_data), PyObject_to_object(out_coords) + + def equals(self, SparseTensorCOO other): + """ + Return true if sparse tensors contains exactly equal data + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseTensorCOO): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + return [frombytes(x) for x in tuple(self.stp.dim_names())] + + @property + def non_zero_length(self): + return self.stp.non_zero_length() + + +cdef class SparseTensorCSR: + """ + A sparse CSR tensor. 
+ """ + + def __init__(self): + raise TypeError("Do not call SparseTensorCSR's constructor directly, " + "use one of the `pyarrow.SparseTensorCSR.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseTensorCSR]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """ +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseTensorCSR + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, indptr, indices, shape, dim_names=None): + """ + Create arrow::SparseTensorCSR from numpy.ndarrays + """ + cdef shared_ptr[CSparseTensorCSR] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for SparseTensorCSR indices + indptr = np.require(indptr, dtype='i8') + indices = np.require(indices, dtype='i8') + if indptr.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseTensorCSR indptr") + if indices.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseTensorCSR indices") + + check_status(NdarraysToSparseTensorCSR(c_default_memory_pool(), + data, indptr, indices, c_shape, c_dim_names, + &csparse_tensor)) + return pyarrow_wrap_sparse_tensor_csr(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseTensorCSR + """ + cdef shared_ptr[CSparseTensorCSR] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseTensorCSR(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_tensor_csr(csparse_tensor) + + def 
to_numpy(self): + """ + Convert arrow::SparseTensorCSR to numpy.ndarrays with zero copy + """ + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseTensorCSRToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, &out_indices)) + return (PyObject_to_object(out_data), PyObject_to_object(out_indptr), + PyObject_to_object(out_indices)) + + def equals(self, SparseTensorCSR other): + """ + Return true if sparse tensors contains exactly equal data + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseTensorCSR): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + return [frombytes(x) for x in tuple(self.stp.dim_names())] + + @property + def non_zero_length(self): + return self.stp.non_zero_length() diff --git a/python/pyarrow/tests/test_sparse_tensor.py b/python/pyarrow/tests/test_sparse_tensor.py new file mode 100644 index 00000000000..68564dacf4b --- /dev/null +++ b/python/pyarrow/tests/test_sparse_tensor.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import sys + +import numpy as np +import pyarrow as pa + + +tensor_type_pairs = [ + ('i1', pa.int8()), + ('i2', pa.int16()), + ('i4', pa.int32()), + ('i8', pa.int64()), + ('u1', pa.uint8()), + ('u2', pa.uint16()), + ('u4', pa.uint32()), + ('u8', pa.uint64()), + ('f2', pa.float16()), + ('f4', pa.float32()), + ('f8', pa.float64()) +] + + +@pytest.mark.parametrize('sparse_tensor_type', [ + pa.SparseTensorCSR, + pa.SparseTensorCOO, +]) +def test_sparse_tensor_attrs(sparse_tensor_type): + data = np.array([ + [0, 1, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 0], + [0, 3, 0, 0, 0], + ]) + dim_names = ['x', 'y'] + sparse_tensor = sparse_tensor_type.from_dense_numpy(data, dim_names) + + assert sparse_tensor.ndim == 2 + assert sparse_tensor.size == 25 + assert sparse_tensor.shape == data.shape + assert sparse_tensor.is_mutable + assert sparse_tensor.dim_name(0) == dim_names[0] + assert sparse_tensor.dim_names == dim_names + assert sparse_tensor.non_zero_length == 4 + + +def test_sparse_tensor_coo_base_object(): + data = np.array([[4], [9], [7], [5]]) + coords = np.array([[0, 0], [0, 2], [1, 1], [3, 3]]) + array = np.array([[4, 0, 9, 0], + [0, 7, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 5]]) + sparse_tensor = pa.SparseTensorCOO.from_dense_numpy(array) + n = sys.getrefcount(sparse_tensor) + result_data, result_coords = sparse_tensor.to_numpy() + assert sys.getrefcount(sparse_tensor) == n + 2 + + sparse_tensor = None + assert np.array_equal(data, result_data) + assert np.array_equal(coords, result_coords) + assert 
result_coords.flags.f_contiguous # column-major + + +def test_sparse_tensor_csr_base_object(): + data = np.array([[1], [2], [3], [4], [5], [6]]) + indptr = np.array([0, 2, 3, 6]) + indices = np.array([0, 2, 2, 0, 1, 2]) + array = np.array([[1, 0, 2], + [0, 0, 3], + [4, 5, 6]]) + + sparse_tensor = pa.SparseTensorCSR.from_dense_numpy(array) + n = sys.getrefcount(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sys.getrefcount(sparse_tensor) == n + 3 + + sparse_tensor = None + assert np.array_equal(data, result_data) + assert np.array_equal(indptr, result_indptr) + assert np.array_equal(indices, result_indices) + + +@pytest.mark.parametrize('sparse_tensor_type', [ + pa.SparseTensorCSR, + pa.SparseTensorCOO, +]) +def test_sparse_tensor_equals(sparse_tensor_type): + def eq(a, b): + assert a.equals(b) + assert a == b + assert not (a != b) + + def ne(a, b): + assert not a.equals(b) + assert not (a == b) + assert a != b + + data = np.random.randn(10, 6)[::, ::2] + sparse_tensor1 = sparse_tensor_type.from_dense_numpy(data) + sparse_tensor2 = sparse_tensor_type.from_dense_numpy( + np.ascontiguousarray(data)) + eq(sparse_tensor1, sparse_tensor2) + data = data.copy() + data[9, 0] = 1.0 + sparse_tensor2 = sparse_tensor_type.from_dense_numpy( + np.ascontiguousarray(data)) + ne(sparse_tensor1, sparse_tensor2) + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_tensor_coo_from_dense(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[4], [9], [7], [5]]).astype(dtype) + coords = np.array([[0, 0], [0, 2], [1, 1], [3, 3]]) + array = np.array([[4, 0, 9, 0], + [0, 7, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 5]]).astype(dtype) + tensor = pa.Tensor.from_numpy(array) + + # Test from numpy array + sparse_tensor = pa.SparseTensorCOO.from_dense_numpy(array) + repr(sparse_tensor) + assert sparse_tensor.type == arrow_type + result_data, result_coords = sparse_tensor.to_numpy() + assert 
np.array_equal(data, result_data) + assert np.array_equal(coords, result_coords) + + # Test from Tensor + sparse_tensor = pa.SparseTensorCOO.from_tensor(tensor) + repr(sparse_tensor) + assert sparse_tensor.type == arrow_type + result_data, result_coords = sparse_tensor.to_numpy() + assert np.array_equal(data, result_data) + assert np.array_equal(coords, result_coords) + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_tensor_csr_from_dense(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + dense_data = np.array([[1, 0, 2], + [0, 0, 3], + [4, 5, 6]]).astype(dtype) + + data = np.array([[1], [2], [3], [4], [5], [6]]) + indptr = np.array([0, 2, 3, 6]) + indices = np.array([0, 2, 2, 0, 1, 2]) + tensor = pa.Tensor.from_numpy(dense_data) + + # Test from numpy array + sparse_tensor = pa.SparseTensorCSR.from_dense_numpy(dense_data) + repr(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert np.array_equal(data, result_data) + assert np.array_equal(indptr, result_indptr) + assert np.array_equal(indices, result_indices) + + # Test from Tensor + sparse_tensor = pa.SparseTensorCSR.from_tensor(tensor) + repr(sparse_tensor) + assert sparse_tensor.type == arrow_type + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert np.array_equal(data, result_data) + assert np.array_equal(indptr, result_indptr) + assert np.array_equal(indices, result_indices) + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_tensor_coo_numpy_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[4], [9], [7], [5]]).astype(dtype) + coords = np.array([[0, 0], [3, 3], [1, 1], [0, 2]]) + shape = (4, 4) + dim_names = ["x", "y"] + + sparse_tensor = pa.SparseTensorCOO.from_numpy(data, coords, shape, + dim_names) + repr(sparse_tensor) + assert sparse_tensor.type == arrow_type + result_data, result_coords = sparse_tensor.to_numpy() + 
assert np.array_equal(data, result_data) + assert np.array_equal(coords, result_coords) + assert sparse_tensor.dim_names == dim_names + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_tensor_csr_numpy_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[1], [2], [3], [4], [5], [6]]).astype(dtype) + indptr = np.array([0, 2, 3, 6]) + indices = np.array([0, 2, 2, 0, 1, 2]) + shape = (3, 3) + dim_names = ["x", "y"] + + sparse_tensor = pa.SparseTensorCSR.from_numpy(data, indptr, indices, + shape, dim_names) + repr(sparse_tensor) + assert sparse_tensor.type == arrow_type + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert np.array_equal(data, result_data) + assert np.array_equal(indptr, result_indptr) + assert np.array_equal(indices, result_indices) + assert sparse_tensor.dim_names == dim_names diff --git a/python/pyarrow/tests/test_tensor.py b/python/pyarrow/tests/test_tensor.py index 188a4a5e1a5..13f05d27489 100644 --- a/python/pyarrow/tests/test_tensor.py +++ b/python/pyarrow/tests/test_tensor.py @@ -23,12 +23,28 @@ import pyarrow as pa +tensor_type_pairs = [ + ('i1', pa.int8()), + ('i2', pa.int16()), + ('i4', pa.int32()), + ('i8', pa.int64()), + ('u1', pa.uint8()), + ('u2', pa.uint16()), + ('u4', pa.uint32()), + ('u8', pa.uint64()), + ('f2', pa.float16()), + ('f4', pa.float32()), + ('f8', pa.float64()) +] + + def test_tensor_attrs(): data = np.random.randn(10, 4) tensor = pa.Tensor.from_numpy(data) assert tensor.ndim == 2 + assert tensor.dim_names == [] assert tensor.size == 40 assert tensor.shape == data.shape assert tensor.strides == data.strides @@ -42,6 +58,13 @@ def test_tensor_attrs(): tensor = pa.Tensor.from_numpy(data2) assert not tensor.is_mutable + # With dim_names + tensor = pa.Tensor.from_numpy(data, dim_names=('x', 'y')) + assert tensor.ndim == 2 + assert tensor.dim_names == ['x', 'y'] + assert tensor.dim_name(0) == 'x' + assert tensor.dim_name(1) == 'y' + 
def test_tensor_base_object(): tensor = pa.Tensor.from_numpy(np.random.randn(10, 4)) @@ -50,19 +73,7 @@ def test_tensor_base_object(): assert sys.getrefcount(tensor) == n + 1 -@pytest.mark.parametrize('dtype_str,arrow_type', [ - ('i1', pa.int8()), - ('i2', pa.int16()), - ('i4', pa.int32()), - ('i8', pa.int64()), - ('u1', pa.uint8()), - ('u2', pa.uint16()), - ('u4', pa.uint32()), - ('u8', pa.uint64()), - ('f2', pa.float16()), - ('f4', pa.float32()), - ('f8', pa.float64()) -]) +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) def test_tensor_numpy_roundtrip(dtype_str, arrow_type): dtype = np.dtype(dtype_str) data = (100 * np.random.randn(10, 4)).astype(dtype) @@ -76,15 +87,6 @@ def test_tensor_numpy_roundtrip(dtype_str, arrow_type): assert (data == result).all() -def _try_delete(path): - import gc - gc.collect() - try: - os.remove(path) - except os.error: - pass - - def test_tensor_ipc_roundtrip(tmpdir): data = np.random.randn(10, 4) tensor = pa.Tensor.from_numpy(data) From 6019dbc8d8defa7081c5cebc5afc81ad48cbbcd4 Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Tue, 2 Jul 2019 11:08:59 +0200 Subject: [PATCH 27/52] ARROW-5731: [CI] Switch turbodbc branch for integration testing Author: Uwe L. 
Korn Closes #4751 from xhochy/ARROW-5731 and squashes the following commits: f4519afe5 ARROW-5731: Switch turbodbc branch for integration testing --- integration/turbodbc/runtest.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/integration/turbodbc/runtest.sh b/integration/turbodbc/runtest.sh index 31f924336fe..874cba05434 100755 --- a/integration/turbodbc/runtest.sh +++ b/integration/turbodbc/runtest.sh @@ -25,9 +25,8 @@ python -c "import pyarrow.orc" python -c "import pyarrow.parquet" pushd /tmp -git clone https://github.com/xhochy/turbodbc.git +git clone https://github.com/blue-yonder/turbodbc.git pushd turbodbc -git checkout arrow-0.13.0-prep git submodule update --init --recursive service postgresql start From 4b09ae3d8b6a95fa52def5c884c0c1beb4862543 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 2 Jul 2019 11:21:51 +0200 Subject: [PATCH 28/52] ARROW-5753: [Rust] Fix test failure in CI code coverage Author: Chao Sun Closes #4748 from sunchao/ARROW-5753 and squashes the following commits: f66d0bf6b Remove warnings in rustfmt 4a22e6b5a ARROW-5753: Fix test failure in CI code coverage --- .travis.yml | 2 ++ rust/rustfmt.toml | 5 +---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 117f5dca683..793e86aa789 100644 --- a/.travis.yml +++ b/.travis.yml @@ -307,6 +307,8 @@ matrix: after_success: - pushd ${TRAVIS_BUILD_DIR}/rust # Run coverage for codecov.io + - export ARROW_TEST_DATA=$TRAVIS_BUILD_DIR/testing/data + - export PARQUET_TEST_DATA=$TRAVIS_BUILD_DIR/cpp/submodules/parquet-testing/data - cargo tarpaulin --out Xml - bash <(curl -s https://codecov.io/bash) || echo "Codecov did not collect coverage reports" - name: Go diff --git a/rust/rustfmt.toml b/rust/rustfmt.toml index b692119bbc1..418b9e2acbb 100644 --- a/rust/rustfmt.toml +++ b/rust/rustfmt.toml @@ -15,7 +15,4 @@ # specific language governing permissions and limitations # under the License. 
-max_width = 90 -wrap_comments = true -format_doc_comments = true -comment_width = 90 \ No newline at end of file +max_width = 90 \ No newline at end of file From 1bbeb35b4d2eac8c87beb000e401edd15aeca03d Mon Sep 17 00:00:00 2001 From: Micah Kornfield Date: Tue, 2 Jul 2019 11:32:48 +0200 Subject: [PATCH 29/52] ARROW-5380: [C++] Fix memory alignment UBSan errors. - Add utility methods for unaligned loads use where errors are discovered. - Upgrade version of flatbuffers to avoid issues with unaligned load in that library - Discover bug in spec that makes zero-copy well defined behavior virtually impossible with flatbuffers (need to discuss on ML). For now I'm not turning on ASAN and will file a follow-up JIRA to track this. Still needed: - [ ] Performance testing - [X] Discuss flatbuffers issues (I sent e-mail to LM) Author: Micah Kornfield Author: emkornfield Closes #4757 from emkornfield/ubsan_mem and squashes the following commits: 5528584a7 remove TODO db49fbbb4 Ubsan excluding flatbuffers --- cpp/src/arrow/util/bpacking.h | 3409 +++++++++++++++++------------- cpp/src/arrow/util/hashing.h | 9 +- cpp/src/arrow/util/ubsan.h | 16 + cpp/src/parquet/arrow/reader.cc | 20 +- cpp/src/parquet/arrow/writer.h | 5 +- cpp/src/parquet/column_reader.cc | 3 +- cpp/src/parquet/encoding.cc | 11 +- cpp/src/parquet/file_reader.cc | 3 +- cpp/src/plasma/common.cc | 4 +- 9 files changed, 2015 insertions(+), 1465 deletions(-) diff --git a/cpp/src/arrow/util/bpacking.h b/cpp/src/arrow/util/bpacking.h index 14258cff6e4..98c2e7deaee 100644 --- a/cpp/src/arrow/util/bpacking.h +++ b/cpp/src/arrow/util/bpacking.h @@ -28,74 +28,76 @@ #define ARROW_UTIL_BPACKING_H #include "arrow/util/logging.h" +#include "arrow/util/ubsan.h" namespace arrow { namespace internal { inline const uint32_t* unpack1_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) & 1; + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) & 1; out++; - *out = ((*in) >> 1) & 1; + *out = (inl >> 1) & 1; out++; - *out = 
((*in) >> 2) & 1; + *out = (inl >> 2) & 1; out++; - *out = ((*in) >> 3) & 1; + *out = (inl >> 3) & 1; out++; - *out = ((*in) >> 4) & 1; + *out = (inl >> 4) & 1; out++; - *out = ((*in) >> 5) & 1; + *out = (inl >> 5) & 1; out++; - *out = ((*in) >> 6) & 1; + *out = (inl >> 6) & 1; out++; - *out = ((*in) >> 7) & 1; + *out = (inl >> 7) & 1; out++; - *out = ((*in) >> 8) & 1; + *out = (inl >> 8) & 1; out++; - *out = ((*in) >> 9) & 1; + *out = (inl >> 9) & 1; out++; - *out = ((*in) >> 10) & 1; + *out = (inl >> 10) & 1; out++; - *out = ((*in) >> 11) & 1; + *out = (inl >> 11) & 1; out++; - *out = ((*in) >> 12) & 1; + *out = (inl >> 12) & 1; out++; - *out = ((*in) >> 13) & 1; + *out = (inl >> 13) & 1; out++; - *out = ((*in) >> 14) & 1; + *out = (inl >> 14) & 1; out++; - *out = ((*in) >> 15) & 1; + *out = (inl >> 15) & 1; out++; - *out = ((*in) >> 16) & 1; + *out = (inl >> 16) & 1; out++; - *out = ((*in) >> 17) & 1; + *out = (inl >> 17) & 1; out++; - *out = ((*in) >> 18) & 1; + *out = (inl >> 18) & 1; out++; - *out = ((*in) >> 19) & 1; + *out = (inl >> 19) & 1; out++; - *out = ((*in) >> 20) & 1; + *out = (inl >> 20) & 1; out++; - *out = ((*in) >> 21) & 1; + *out = (inl >> 21) & 1; out++; - *out = ((*in) >> 22) & 1; + *out = (inl >> 22) & 1; out++; - *out = ((*in) >> 23) & 1; + *out = (inl >> 23) & 1; out++; - *out = ((*in) >> 24) & 1; + *out = (inl >> 24) & 1; out++; - *out = ((*in) >> 25) & 1; + *out = (inl >> 25) & 1; out++; - *out = ((*in) >> 26) & 1; + *out = (inl >> 26) & 1; out++; - *out = ((*in) >> 27) & 1; + *out = (inl >> 27) & 1; out++; - *out = ((*in) >> 28) & 1; + *out = (inl >> 28) & 1; out++; - *out = ((*in) >> 29) & 1; + *out = (inl >> 29) & 1; out++; - *out = ((*in) >> 30) & 1; + *out = (inl >> 30) & 1; out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; out++; @@ -103,70 +105,72 @@ inline const uint32_t* unpack1_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack2_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 
2); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 2); out++; - *out = ((*in) >> 2) % (1U << 2); + *out = (inl >> 2) % (1U << 2); out++; - *out = ((*in) >> 4) % (1U << 2); + *out = (inl >> 4) % (1U << 2); out++; - *out = ((*in) >> 6) % (1U << 2); + *out = (inl >> 6) % (1U << 2); out++; - *out = ((*in) >> 8) % (1U << 2); + *out = (inl >> 8) % (1U << 2); out++; - *out = ((*in) >> 10) % (1U << 2); + *out = (inl >> 10) % (1U << 2); out++; - *out = ((*in) >> 12) % (1U << 2); + *out = (inl >> 12) % (1U << 2); out++; - *out = ((*in) >> 14) % (1U << 2); + *out = (inl >> 14) % (1U << 2); out++; - *out = ((*in) >> 16) % (1U << 2); + *out = (inl >> 16) % (1U << 2); out++; - *out = ((*in) >> 18) % (1U << 2); + *out = (inl >> 18) % (1U << 2); out++; - *out = ((*in) >> 20) % (1U << 2); + *out = (inl >> 20) % (1U << 2); out++; - *out = ((*in) >> 22) % (1U << 2); + *out = (inl >> 22) % (1U << 2); out++; - *out = ((*in) >> 24) % (1U << 2); + *out = (inl >> 24) % (1U << 2); out++; - *out = ((*in) >> 26) % (1U << 2); + *out = (inl >> 26) % (1U << 2); out++; - *out = ((*in) >> 28) % (1U << 2); + *out = (inl >> 28) % (1U << 2); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 2); + *out = (inl >> 0) % (1U << 2); out++; - *out = ((*in) >> 2) % (1U << 2); + *out = (inl >> 2) % (1U << 2); out++; - *out = ((*in) >> 4) % (1U << 2); + *out = (inl >> 4) % (1U << 2); out++; - *out = ((*in) >> 6) % (1U << 2); + *out = (inl >> 6) % (1U << 2); out++; - *out = ((*in) >> 8) % (1U << 2); + *out = (inl >> 8) % (1U << 2); out++; - *out = ((*in) >> 10) % (1U << 2); + *out = (inl >> 10) % (1U << 2); out++; - *out = ((*in) >> 12) % (1U << 2); + *out = (inl >> 12) % (1U << 2); out++; - *out = ((*in) >> 14) % (1U << 2); + *out = (inl >> 14) % (1U << 2); out++; - *out = ((*in) >> 16) % (1U << 2); + *out = (inl >> 16) % (1U << 2); out++; - *out = ((*in) >> 18) % (1U << 2); + *out = (inl >> 18) % (1U << 2); out++; - 
*out = ((*in) >> 20) % (1U << 2); + *out = (inl >> 20) % (1U << 2); out++; - *out = ((*in) >> 22) % (1U << 2); + *out = (inl >> 22) % (1U << 2); out++; - *out = ((*in) >> 24) % (1U << 2); + *out = (inl >> 24) % (1U << 2); out++; - *out = ((*in) >> 26) % (1U << 2); + *out = (inl >> 26) % (1U << 2); out++; - *out = ((*in) >> 28) % (1U << 2); + *out = (inl >> 28) % (1U << 2); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; out++; @@ -174,73 +178,76 @@ inline const uint32_t* unpack2_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack3_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 3); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 3); out++; - *out = ((*in) >> 3) % (1U << 3); + *out = (inl >> 3) % (1U << 3); out++; - *out = ((*in) >> 6) % (1U << 3); + *out = (inl >> 6) % (1U << 3); out++; - *out = ((*in) >> 9) % (1U << 3); + *out = (inl >> 9) % (1U << 3); out++; - *out = ((*in) >> 12) % (1U << 3); + *out = (inl >> 12) % (1U << 3); out++; - *out = ((*in) >> 15) % (1U << 3); + *out = (inl >> 15) % (1U << 3); out++; - *out = ((*in) >> 18) % (1U << 3); + *out = (inl >> 18) % (1U << 3); out++; - *out = ((*in) >> 21) % (1U << 3); + *out = (inl >> 21) % (1U << 3); out++; - *out = ((*in) >> 24) % (1U << 3); + *out = (inl >> 24) % (1U << 3); out++; - *out = ((*in) >> 27) % (1U << 3); + *out = (inl >> 27) % (1U << 3); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 1)) << (3 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (3 - 1); out++; - *out = ((*in) >> 1) % (1U << 3); + *out = (inl >> 1) % (1U << 3); out++; - *out = ((*in) >> 4) % (1U << 3); + *out = (inl >> 4) % (1U << 3); out++; - *out = ((*in) >> 7) % (1U << 3); + *out = (inl >> 7) % (1U << 3); out++; - *out = ((*in) >> 10) % (1U << 3); + *out = (inl >> 10) % (1U << 3); out++; - *out = ((*in) >> 13) % (1U << 3); + *out = (inl >> 13) % (1U << 3); out++; - *out = ((*in) >> 16) % (1U << 3); + *out = 
(inl >> 16) % (1U << 3); out++; - *out = ((*in) >> 19) % (1U << 3); + *out = (inl >> 19) % (1U << 3); out++; - *out = ((*in) >> 22) % (1U << 3); + *out = (inl >> 22) % (1U << 3); out++; - *out = ((*in) >> 25) % (1U << 3); + *out = (inl >> 25) % (1U << 3); out++; - *out = ((*in) >> 28) % (1U << 3); + *out = (inl >> 28) % (1U << 3); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 2)) << (3 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (3 - 2); out++; - *out = ((*in) >> 2) % (1U << 3); + *out = (inl >> 2) % (1U << 3); out++; - *out = ((*in) >> 5) % (1U << 3); + *out = (inl >> 5) % (1U << 3); out++; - *out = ((*in) >> 8) % (1U << 3); + *out = (inl >> 8) % (1U << 3); out++; - *out = ((*in) >> 11) % (1U << 3); + *out = (inl >> 11) % (1U << 3); out++; - *out = ((*in) >> 14) % (1U << 3); + *out = (inl >> 14) % (1U << 3); out++; - *out = ((*in) >> 17) % (1U << 3); + *out = (inl >> 17) % (1U << 3); out++; - *out = ((*in) >> 20) % (1U << 3); + *out = (inl >> 20) % (1U << 3); out++; - *out = ((*in) >> 23) % (1U << 3); + *out = (inl >> 23) % (1U << 3); out++; - *out = ((*in) >> 26) % (1U << 3); + *out = (inl >> 26) % (1U << 3); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; out++; @@ -248,72 +255,76 @@ inline const uint32_t* unpack3_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack4_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 4); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 4); out++; - *out = ((*in) >> 4) % (1U << 4); + *out = (inl >> 4) % (1U << 4); out++; - *out = ((*in) >> 8) % (1U << 4); + *out = (inl >> 8) % (1U << 4); out++; - *out = ((*in) >> 12) % (1U << 4); + *out = (inl >> 12) % (1U << 4); out++; - *out = ((*in) >> 16) % (1U << 4); + *out = (inl >> 16) % (1U << 4); out++; - *out = ((*in) >> 20) % (1U << 4); + *out = (inl >> 20) % (1U << 4); out++; - *out = ((*in) >> 24) % (1U << 4); + *out = (inl >> 24) % (1U << 4); out++; - *out = 
((*in) >> 28); + *out = (inl >> 28); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 4); + *out = (inl >> 0) % (1U << 4); out++; - *out = ((*in) >> 4) % (1U << 4); + *out = (inl >> 4) % (1U << 4); out++; - *out = ((*in) >> 8) % (1U << 4); + *out = (inl >> 8) % (1U << 4); out++; - *out = ((*in) >> 12) % (1U << 4); + *out = (inl >> 12) % (1U << 4); out++; - *out = ((*in) >> 16) % (1U << 4); + *out = (inl >> 16) % (1U << 4); out++; - *out = ((*in) >> 20) % (1U << 4); + *out = (inl >> 20) % (1U << 4); out++; - *out = ((*in) >> 24) % (1U << 4); + *out = (inl >> 24) % (1U << 4); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 4); + *out = (inl >> 0) % (1U << 4); out++; - *out = ((*in) >> 4) % (1U << 4); + *out = (inl >> 4) % (1U << 4); out++; - *out = ((*in) >> 8) % (1U << 4); + *out = (inl >> 8) % (1U << 4); out++; - *out = ((*in) >> 12) % (1U << 4); + *out = (inl >> 12) % (1U << 4); out++; - *out = ((*in) >> 16) % (1U << 4); + *out = (inl >> 16) % (1U << 4); out++; - *out = ((*in) >> 20) % (1U << 4); + *out = (inl >> 20) % (1U << 4); out++; - *out = ((*in) >> 24) % (1U << 4); + *out = (inl >> 24) % (1U << 4); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 4); + *out = (inl >> 0) % (1U << 4); out++; - *out = ((*in) >> 4) % (1U << 4); + *out = (inl >> 4) % (1U << 4); out++; - *out = ((*in) >> 8) % (1U << 4); + *out = (inl >> 8) % (1U << 4); out++; - *out = ((*in) >> 12) % (1U << 4); + *out = (inl >> 12) % (1U << 4); out++; - *out = ((*in) >> 16) % (1U << 4); + *out = (inl >> 16) % (1U << 4); out++; - *out = ((*in) >> 20) % (1U << 4); + *out = (inl >> 20) % (1U << 4); out++; - *out = ((*in) >> 24) % (1U << 4); + *out = (inl >> 24) % (1U << 4); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; out++; @@ -321,77 +332,82 @@ inline const uint32_t* unpack4_32(const uint32_t* in, uint32_t* 
out) { } inline const uint32_t* unpack5_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 5); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 5); out++; - *out = ((*in) >> 5) % (1U << 5); + *out = (inl >> 5) % (1U << 5); out++; - *out = ((*in) >> 10) % (1U << 5); + *out = (inl >> 10) % (1U << 5); out++; - *out = ((*in) >> 15) % (1U << 5); + *out = (inl >> 15) % (1U << 5); out++; - *out = ((*in) >> 20) % (1U << 5); + *out = (inl >> 20) % (1U << 5); out++; - *out = ((*in) >> 25) % (1U << 5); + *out = (inl >> 25) % (1U << 5); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 3)) << (5 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (5 - 3); out++; - *out = ((*in) >> 3) % (1U << 5); + *out = (inl >> 3) % (1U << 5); out++; - *out = ((*in) >> 8) % (1U << 5); + *out = (inl >> 8) % (1U << 5); out++; - *out = ((*in) >> 13) % (1U << 5); + *out = (inl >> 13) % (1U << 5); out++; - *out = ((*in) >> 18) % (1U << 5); + *out = (inl >> 18) % (1U << 5); out++; - *out = ((*in) >> 23) % (1U << 5); + *out = (inl >> 23) % (1U << 5); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 1)) << (5 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (5 - 1); out++; - *out = ((*in) >> 1) % (1U << 5); + *out = (inl >> 1) % (1U << 5); out++; - *out = ((*in) >> 6) % (1U << 5); + *out = (inl >> 6) % (1U << 5); out++; - *out = ((*in) >> 11) % (1U << 5); + *out = (inl >> 11) % (1U << 5); out++; - *out = ((*in) >> 16) % (1U << 5); + *out = (inl >> 16) % (1U << 5); out++; - *out = ((*in) >> 21) % (1U << 5); + *out = (inl >> 21) % (1U << 5); out++; - *out = ((*in) >> 26) % (1U << 5); + *out = (inl >> 26) % (1U << 5); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 4)) << (5 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (5 - 4); out++; - *out = ((*in) >> 4) % (1U << 5); + *out = (inl >> 4) % (1U << 5); out++; - 
*out = ((*in) >> 9) % (1U << 5); + *out = (inl >> 9) % (1U << 5); out++; - *out = ((*in) >> 14) % (1U << 5); + *out = (inl >> 14) % (1U << 5); out++; - *out = ((*in) >> 19) % (1U << 5); + *out = (inl >> 19) % (1U << 5); out++; - *out = ((*in) >> 24) % (1U << 5); + *out = (inl >> 24) % (1U << 5); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 2)) << (5 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (5 - 2); out++; - *out = ((*in) >> 2) % (1U << 5); + *out = (inl >> 2) % (1U << 5); out++; - *out = ((*in) >> 7) % (1U << 5); + *out = (inl >> 7) % (1U << 5); out++; - *out = ((*in) >> 12) % (1U << 5); + *out = (inl >> 12) % (1U << 5); out++; - *out = ((*in) >> 17) % (1U << 5); + *out = (inl >> 17) % (1U << 5); out++; - *out = ((*in) >> 22) % (1U << 5); + *out = (inl >> 22) % (1U << 5); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; out++; @@ -399,78 +415,84 @@ inline const uint32_t* unpack5_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack6_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 6); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 6); out++; - *out = ((*in) >> 6) % (1U << 6); + *out = (inl >> 6) % (1U << 6); out++; - *out = ((*in) >> 12) % (1U << 6); + *out = (inl >> 12) % (1U << 6); out++; - *out = ((*in) >> 18) % (1U << 6); + *out = (inl >> 18) % (1U << 6); out++; - *out = ((*in) >> 24) % (1U << 6); + *out = (inl >> 24) % (1U << 6); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 4)) << (6 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (6 - 4); out++; - *out = ((*in) >> 4) % (1U << 6); + *out = (inl >> 4) % (1U << 6); out++; - *out = ((*in) >> 10) % (1U << 6); + *out = (inl >> 10) % (1U << 6); out++; - *out = ((*in) >> 16) % (1U << 6); + *out = (inl >> 16) % (1U << 6); out++; - *out = ((*in) >> 22) % (1U << 6); + *out = (inl >> 22) % (1U << 6); out++; - *out = ((*in) >> 28); + 
*out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 2)) << (6 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (6 - 2); out++; - *out = ((*in) >> 2) % (1U << 6); + *out = (inl >> 2) % (1U << 6); out++; - *out = ((*in) >> 8) % (1U << 6); + *out = (inl >> 8) % (1U << 6); out++; - *out = ((*in) >> 14) % (1U << 6); + *out = (inl >> 14) % (1U << 6); out++; - *out = ((*in) >> 20) % (1U << 6); + *out = (inl >> 20) % (1U << 6); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 6); + *out = (inl >> 0) % (1U << 6); out++; - *out = ((*in) >> 6) % (1U << 6); + *out = (inl >> 6) % (1U << 6); out++; - *out = ((*in) >> 12) % (1U << 6); + *out = (inl >> 12) % (1U << 6); out++; - *out = ((*in) >> 18) % (1U << 6); + *out = (inl >> 18) % (1U << 6); out++; - *out = ((*in) >> 24) % (1U << 6); + *out = (inl >> 24) % (1U << 6); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 4)) << (6 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (6 - 4); out++; - *out = ((*in) >> 4) % (1U << 6); + *out = (inl >> 4) % (1U << 6); out++; - *out = ((*in) >> 10) % (1U << 6); + *out = (inl >> 10) % (1U << 6); out++; - *out = ((*in) >> 16) % (1U << 6); + *out = (inl >> 16) % (1U << 6); out++; - *out = ((*in) >> 22) % (1U << 6); + *out = (inl >> 22) % (1U << 6); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 2)) << (6 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (6 - 2); out++; - *out = ((*in) >> 2) % (1U << 6); + *out = (inl >> 2) % (1U << 6); out++; - *out = ((*in) >> 8) % (1U << 6); + *out = (inl >> 8) % (1U << 6); out++; - *out = ((*in) >> 14) % (1U << 6); + *out = (inl >> 14) % (1U << 6); out++; - *out = ((*in) >> 20) % (1U << 6); + *out = (inl >> 20) % (1U << 6); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; out++; @@ -478,81 +500,88 @@ inline const uint32_t* unpack6_32(const uint32_t* in, 
uint32_t* out) { } inline const uint32_t* unpack7_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 7); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 7); out++; - *out = ((*in) >> 7) % (1U << 7); + *out = (inl >> 7) % (1U << 7); out++; - *out = ((*in) >> 14) % (1U << 7); + *out = (inl >> 14) % (1U << 7); out++; - *out = ((*in) >> 21) % (1U << 7); + *out = (inl >> 21) % (1U << 7); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 3)) << (7 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (7 - 3); out++; - *out = ((*in) >> 3) % (1U << 7); + *out = (inl >> 3) % (1U << 7); out++; - *out = ((*in) >> 10) % (1U << 7); + *out = (inl >> 10) % (1U << 7); out++; - *out = ((*in) >> 17) % (1U << 7); + *out = (inl >> 17) % (1U << 7); out++; - *out = ((*in) >> 24) % (1U << 7); + *out = (inl >> 24) % (1U << 7); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 6)) << (7 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (7 - 6); out++; - *out = ((*in) >> 6) % (1U << 7); + *out = (inl >> 6) % (1U << 7); out++; - *out = ((*in) >> 13) % (1U << 7); + *out = (inl >> 13) % (1U << 7); out++; - *out = ((*in) >> 20) % (1U << 7); + *out = (inl >> 20) % (1U << 7); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 2)) << (7 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (7 - 2); out++; - *out = ((*in) >> 2) % (1U << 7); + *out = (inl >> 2) % (1U << 7); out++; - *out = ((*in) >> 9) % (1U << 7); + *out = (inl >> 9) % (1U << 7); out++; - *out = ((*in) >> 16) % (1U << 7); + *out = (inl >> 16) % (1U << 7); out++; - *out = ((*in) >> 23) % (1U << 7); + *out = (inl >> 23) % (1U << 7); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 5)) << (7 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (7 - 5); out++; - *out = ((*in) >> 5) % (1U << 7); + *out = (inl >> 
5) % (1U << 7); out++; - *out = ((*in) >> 12) % (1U << 7); + *out = (inl >> 12) % (1U << 7); out++; - *out = ((*in) >> 19) % (1U << 7); + *out = (inl >> 19) % (1U << 7); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 1)) << (7 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (7 - 1); out++; - *out = ((*in) >> 1) % (1U << 7); + *out = (inl >> 1) % (1U << 7); out++; - *out = ((*in) >> 8) % (1U << 7); + *out = (inl >> 8) % (1U << 7); out++; - *out = ((*in) >> 15) % (1U << 7); + *out = (inl >> 15) % (1U << 7); out++; - *out = ((*in) >> 22) % (1U << 7); + *out = (inl >> 22) % (1U << 7); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 4)) << (7 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (7 - 4); out++; - *out = ((*in) >> 4) % (1U << 7); + *out = (inl >> 4) % (1U << 7); out++; - *out = ((*in) >> 11) % (1U << 7); + *out = (inl >> 11) % (1U << 7); out++; - *out = ((*in) >> 18) % (1U << 7); + *out = (inl >> 18) % (1U << 7); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; out++; @@ -560,76 +589,84 @@ inline const uint32_t* unpack7_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack8_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 8); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 8); out++; - *out = ((*in) >> 8) % (1U << 8); + *out = (inl >> 8) % (1U << 8); out++; - *out = ((*in) >> 16) % (1U << 8); + *out = (inl >> 16) % (1U << 8); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 8); + *out = (inl >> 0) % (1U << 8); out++; - *out = ((*in) >> 8) % (1U << 8); + *out = (inl >> 8) % (1U << 8); out++; - *out = ((*in) >> 16) % (1U << 8); + *out = (inl >> 16) % (1U << 8); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 8); + *out = (inl >> 0) % (1U 
<< 8); out++; - *out = ((*in) >> 8) % (1U << 8); + *out = (inl >> 8) % (1U << 8); out++; - *out = ((*in) >> 16) % (1U << 8); + *out = (inl >> 16) % (1U << 8); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 8); + *out = (inl >> 0) % (1U << 8); out++; - *out = ((*in) >> 8) % (1U << 8); + *out = (inl >> 8) % (1U << 8); out++; - *out = ((*in) >> 16) % (1U << 8); + *out = (inl >> 16) % (1U << 8); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 8); + *out = (inl >> 0) % (1U << 8); out++; - *out = ((*in) >> 8) % (1U << 8); + *out = (inl >> 8) % (1U << 8); out++; - *out = ((*in) >> 16) % (1U << 8); + *out = (inl >> 16) % (1U << 8); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 8); + *out = (inl >> 0) % (1U << 8); out++; - *out = ((*in) >> 8) % (1U << 8); + *out = (inl >> 8) % (1U << 8); out++; - *out = ((*in) >> 16) % (1U << 8); + *out = (inl >> 16) % (1U << 8); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 8); + *out = (inl >> 0) % (1U << 8); out++; - *out = ((*in) >> 8) % (1U << 8); + *out = (inl >> 8) % (1U << 8); out++; - *out = ((*in) >> 16) % (1U << 8); + *out = (inl >> 16) % (1U << 8); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 8); + *out = (inl >> 0) % (1U << 8); out++; - *out = ((*in) >> 8) % (1U << 8); + *out = (inl >> 8) % (1U << 8); out++; - *out = ((*in) >> 16) % (1U << 8); + *out = (inl >> 16) % (1U << 8); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; out++; @@ -637,85 +674,94 @@ inline const uint32_t* unpack8_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack9_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 9); + 
uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 9); out++; - *out = ((*in) >> 9) % (1U << 9); + *out = (inl >> 9) % (1U << 9); out++; - *out = ((*in) >> 18) % (1U << 9); + *out = (inl >> 18) % (1U << 9); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 4)) << (9 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (9 - 4); out++; - *out = ((*in) >> 4) % (1U << 9); + *out = (inl >> 4) % (1U << 9); out++; - *out = ((*in) >> 13) % (1U << 9); + *out = (inl >> 13) % (1U << 9); out++; - *out = ((*in) >> 22) % (1U << 9); + *out = (inl >> 22) % (1U << 9); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 8)) << (9 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (9 - 8); out++; - *out = ((*in) >> 8) % (1U << 9); + *out = (inl >> 8) % (1U << 9); out++; - *out = ((*in) >> 17) % (1U << 9); + *out = (inl >> 17) % (1U << 9); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 3)) << (9 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (9 - 3); out++; - *out = ((*in) >> 3) % (1U << 9); + *out = (inl >> 3) % (1U << 9); out++; - *out = ((*in) >> 12) % (1U << 9); + *out = (inl >> 12) % (1U << 9); out++; - *out = ((*in) >> 21) % (1U << 9); + *out = (inl >> 21) % (1U << 9); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 7)) << (9 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (9 - 7); out++; - *out = ((*in) >> 7) % (1U << 9); + *out = (inl >> 7) % (1U << 9); out++; - *out = ((*in) >> 16) % (1U << 9); + *out = (inl >> 16) % (1U << 9); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 2)) << (9 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (9 - 2); out++; - *out = ((*in) >> 2) % (1U << 9); + *out = (inl >> 2) % (1U << 9); out++; - *out = ((*in) >> 11) % (1U << 9); + *out = (inl >> 11) % (1U << 9); out++; - *out = 
((*in) >> 20) % (1U << 9); + *out = (inl >> 20) % (1U << 9); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 6)) << (9 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (9 - 6); out++; - *out = ((*in) >> 6) % (1U << 9); + *out = (inl >> 6) % (1U << 9); out++; - *out = ((*in) >> 15) % (1U << 9); + *out = (inl >> 15) % (1U << 9); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 1)) << (9 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (9 - 1); out++; - *out = ((*in) >> 1) % (1U << 9); + *out = (inl >> 1) % (1U << 9); out++; - *out = ((*in) >> 10) % (1U << 9); + *out = (inl >> 10) % (1U << 9); out++; - *out = ((*in) >> 19) % (1U << 9); + *out = (inl >> 19) % (1U << 9); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 5)) << (9 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (9 - 5); out++; - *out = ((*in) >> 5) % (1U << 9); + *out = (inl >> 5) % (1U << 9); out++; - *out = ((*in) >> 14) % (1U << 9); + *out = (inl >> 14) % (1U << 9); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; out++; @@ -723,86 +769,96 @@ inline const uint32_t* unpack9_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack10_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 10); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 10); out++; - *out = ((*in) >> 10) % (1U << 10); + *out = (inl >> 10) % (1U << 10); out++; - *out = ((*in) >> 20) % (1U << 10); + *out = (inl >> 20) % (1U << 10); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 8)) << (10 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (10 - 8); out++; - *out = ((*in) >> 8) % (1U << 10); + *out = (inl >> 8) % (1U << 10); out++; - *out = ((*in) >> 18) % (1U << 10); + *out = (inl >> 18) % (1U << 10); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= 
((*in) % (1U << 6)) << (10 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (10 - 6); out++; - *out = ((*in) >> 6) % (1U << 10); + *out = (inl >> 6) % (1U << 10); out++; - *out = ((*in) >> 16) % (1U << 10); + *out = (inl >> 16) % (1U << 10); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 4)) << (10 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (10 - 4); out++; - *out = ((*in) >> 4) % (1U << 10); + *out = (inl >> 4) % (1U << 10); out++; - *out = ((*in) >> 14) % (1U << 10); + *out = (inl >> 14) % (1U << 10); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 2)) << (10 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (10 - 2); out++; - *out = ((*in) >> 2) % (1U << 10); + *out = (inl >> 2) % (1U << 10); out++; - *out = ((*in) >> 12) % (1U << 10); + *out = (inl >> 12) % (1U << 10); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 10); + *out = (inl >> 0) % (1U << 10); out++; - *out = ((*in) >> 10) % (1U << 10); + *out = (inl >> 10) % (1U << 10); out++; - *out = ((*in) >> 20) % (1U << 10); + *out = (inl >> 20) % (1U << 10); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 8)) << (10 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (10 - 8); out++; - *out = ((*in) >> 8) % (1U << 10); + *out = (inl >> 8) % (1U << 10); out++; - *out = ((*in) >> 18) % (1U << 10); + *out = (inl >> 18) % (1U << 10); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 6)) << (10 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (10 - 6); out++; - *out = ((*in) >> 6) % (1U << 10); + *out = (inl >> 6) % (1U << 10); out++; - *out = ((*in) >> 16) % (1U << 10); + *out = (inl >> 16) % (1U << 10); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 4)) << (10 - 4); + inl = 
util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (10 - 4); out++; - *out = ((*in) >> 4) % (1U << 10); + *out = (inl >> 4) % (1U << 10); out++; - *out = ((*in) >> 14) % (1U << 10); + *out = (inl >> 14) % (1U << 10); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 2)) << (10 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (10 - 2); out++; - *out = ((*in) >> 2) % (1U << 10); + *out = (inl >> 2) % (1U << 10); out++; - *out = ((*in) >> 12) % (1U << 10); + *out = (inl >> 12) % (1U << 10); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; out++; @@ -810,89 +866,100 @@ inline const uint32_t* unpack10_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack11_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 11); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 11); out++; - *out = ((*in) >> 11) % (1U << 11); + *out = (inl >> 11) % (1U << 11); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 1)) << (11 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (11 - 1); out++; - *out = ((*in) >> 1) % (1U << 11); + *out = (inl >> 1) % (1U << 11); out++; - *out = ((*in) >> 12) % (1U << 11); + *out = (inl >> 12) % (1U << 11); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 2)) << (11 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (11 - 2); out++; - *out = ((*in) >> 2) % (1U << 11); + *out = (inl >> 2) % (1U << 11); out++; - *out = ((*in) >> 13) % (1U << 11); + *out = (inl >> 13) % (1U << 11); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 3)) << (11 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (11 - 3); out++; - *out = ((*in) >> 3) % (1U << 11); + *out = (inl >> 3) % (1U << 11); out++; - *out = ((*in) >> 14) % (1U << 11); + *out = (inl >> 14) % (1U << 11); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); 
++in; - *out |= ((*in) % (1U << 4)) << (11 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (11 - 4); out++; - *out = ((*in) >> 4) % (1U << 11); + *out = (inl >> 4) % (1U << 11); out++; - *out = ((*in) >> 15) % (1U << 11); + *out = (inl >> 15) % (1U << 11); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 5)) << (11 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (11 - 5); out++; - *out = ((*in) >> 5) % (1U << 11); + *out = (inl >> 5) % (1U << 11); out++; - *out = ((*in) >> 16) % (1U << 11); + *out = (inl >> 16) % (1U << 11); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 6)) << (11 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (11 - 6); out++; - *out = ((*in) >> 6) % (1U << 11); + *out = (inl >> 6) % (1U << 11); out++; - *out = ((*in) >> 17) % (1U << 11); + *out = (inl >> 17) % (1U << 11); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 7)) << (11 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (11 - 7); out++; - *out = ((*in) >> 7) % (1U << 11); + *out = (inl >> 7) % (1U << 11); out++; - *out = ((*in) >> 18) % (1U << 11); + *out = (inl >> 18) % (1U << 11); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 8)) << (11 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (11 - 8); out++; - *out = ((*in) >> 8) % (1U << 11); + *out = (inl >> 8) % (1U << 11); out++; - *out = ((*in) >> 19) % (1U << 11); + *out = (inl >> 19) % (1U << 11); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 9)) << (11 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (11 - 9); out++; - *out = ((*in) >> 9) % (1U << 11); + *out = (inl >> 9) % (1U << 11); out++; - *out = ((*in) >> 20) % (1U << 11); + *out = (inl >> 20) % (1U << 11); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 10)) 
<< (11 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (11 - 10); out++; - *out = ((*in) >> 10) % (1U << 11); + *out = (inl >> 10) % (1U << 11); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; out++; @@ -900,88 +967,100 @@ inline const uint32_t* unpack11_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack12_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 12); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 12); out++; - *out = ((*in) >> 12) % (1U << 12); + *out = (inl >> 12) % (1U << 12); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 4)) << (12 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (12 - 4); out++; - *out = ((*in) >> 4) % (1U << 12); + *out = (inl >> 4) % (1U << 12); out++; - *out = ((*in) >> 16) % (1U << 12); + *out = (inl >> 16) % (1U << 12); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 8)) << (12 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (12 - 8); out++; - *out = ((*in) >> 8) % (1U << 12); + *out = (inl >> 8) % (1U << 12); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 12); + *out = (inl >> 0) % (1U << 12); out++; - *out = ((*in) >> 12) % (1U << 12); + *out = (inl >> 12) % (1U << 12); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 4)) << (12 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (12 - 4); out++; - *out = ((*in) >> 4) % (1U << 12); + *out = (inl >> 4) % (1U << 12); out++; - *out = ((*in) >> 16) % (1U << 12); + *out = (inl >> 16) % (1U << 12); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 8)) << (12 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (12 - 8); out++; - *out = ((*in) >> 8) % (1U << 12); + *out = (inl >> 8) % (1U << 12); out++; - *out = 
((*in) >> 20); + *out = (inl >> 20); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 12); + *out = (inl >> 0) % (1U << 12); out++; - *out = ((*in) >> 12) % (1U << 12); + *out = (inl >> 12) % (1U << 12); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 4)) << (12 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (12 - 4); out++; - *out = ((*in) >> 4) % (1U << 12); + *out = (inl >> 4) % (1U << 12); out++; - *out = ((*in) >> 16) % (1U << 12); + *out = (inl >> 16) % (1U << 12); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 8)) << (12 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (12 - 8); out++; - *out = ((*in) >> 8) % (1U << 12); + *out = (inl >> 8) % (1U << 12); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 12); + *out = (inl >> 0) % (1U << 12); out++; - *out = ((*in) >> 12) % (1U << 12); + *out = (inl >> 12) % (1U << 12); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 4)) << (12 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (12 - 4); out++; - *out = ((*in) >> 4) % (1U << 12); + *out = (inl >> 4) % (1U << 12); out++; - *out = ((*in) >> 16) % (1U << 12); + *out = (inl >> 16) % (1U << 12); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 8)) << (12 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (12 - 8); out++; - *out = ((*in) >> 8) % (1U << 12); + *out = (inl >> 8) % (1U << 12); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; out++; @@ -989,93 +1068,106 @@ inline const uint32_t* unpack12_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack13_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 13); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 13); out++; - *out = ((*in) >> 13) % (1U 
<< 13); + *out = (inl >> 13) % (1U << 13); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 7)) << (13 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (13 - 7); out++; - *out = ((*in) >> 7) % (1U << 13); + *out = (inl >> 7) % (1U << 13); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 1)) << (13 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (13 - 1); out++; - *out = ((*in) >> 1) % (1U << 13); + *out = (inl >> 1) % (1U << 13); out++; - *out = ((*in) >> 14) % (1U << 13); + *out = (inl >> 14) % (1U << 13); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 8)) << (13 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (13 - 8); out++; - *out = ((*in) >> 8) % (1U << 13); + *out = (inl >> 8) % (1U << 13); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 2)) << (13 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (13 - 2); out++; - *out = ((*in) >> 2) % (1U << 13); + *out = (inl >> 2) % (1U << 13); out++; - *out = ((*in) >> 15) % (1U << 13); + *out = (inl >> 15) % (1U << 13); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 9)) << (13 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (13 - 9); out++; - *out = ((*in) >> 9) % (1U << 13); + *out = (inl >> 9) % (1U << 13); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 3)) << (13 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (13 - 3); out++; - *out = ((*in) >> 3) % (1U << 13); + *out = (inl >> 3) % (1U << 13); out++; - *out = ((*in) >> 16) % (1U << 13); + *out = (inl >> 16) % (1U << 13); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 10)) << (13 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (13 - 10); out++; - *out = ((*in) >> 10) % (1U << 13); + *out = 
(inl >> 10) % (1U << 13); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 4)) << (13 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (13 - 4); out++; - *out = ((*in) >> 4) % (1U << 13); + *out = (inl >> 4) % (1U << 13); out++; - *out = ((*in) >> 17) % (1U << 13); + *out = (inl >> 17) % (1U << 13); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 11)) << (13 - 11); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (13 - 11); out++; - *out = ((*in) >> 11) % (1U << 13); + *out = (inl >> 11) % (1U << 13); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 5)) << (13 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (13 - 5); out++; - *out = ((*in) >> 5) % (1U << 13); + *out = (inl >> 5) % (1U << 13); out++; - *out = ((*in) >> 18) % (1U << 13); + *out = (inl >> 18) % (1U << 13); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 12)) << (13 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (13 - 12); out++; - *out = ((*in) >> 12) % (1U << 13); + *out = (inl >> 12) % (1U << 13); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 6)) << (13 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (13 - 6); out++; - *out = ((*in) >> 6) % (1U << 13); + *out = (inl >> 6) % (1U << 13); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; out++; @@ -1083,94 +1175,108 @@ inline const uint32_t* unpack13_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack14_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 14); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 14); out++; - *out = ((*in) >> 14) % (1U << 14); + *out = (inl >> 14) % (1U << 14); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 10)) << (14 - 10); + inl = util::SafeLoad(in); + *out 
|= (inl % (1U << 10)) << (14 - 10); out++; - *out = ((*in) >> 10) % (1U << 14); + *out = (inl >> 10) % (1U << 14); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 6)) << (14 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (14 - 6); out++; - *out = ((*in) >> 6) % (1U << 14); + *out = (inl >> 6) % (1U << 14); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 2)) << (14 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (14 - 2); out++; - *out = ((*in) >> 2) % (1U << 14); + *out = (inl >> 2) % (1U << 14); out++; - *out = ((*in) >> 16) % (1U << 14); + *out = (inl >> 16) % (1U << 14); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 12)) << (14 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (14 - 12); out++; - *out = ((*in) >> 12) % (1U << 14); + *out = (inl >> 12) % (1U << 14); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 8)) << (14 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (14 - 8); out++; - *out = ((*in) >> 8) % (1U << 14); + *out = (inl >> 8) % (1U << 14); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 4)) << (14 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (14 - 4); out++; - *out = ((*in) >> 4) % (1U << 14); + *out = (inl >> 4) % (1U << 14); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 14); + *out = (inl >> 0) % (1U << 14); out++; - *out = ((*in) >> 14) % (1U << 14); + *out = (inl >> 14) % (1U << 14); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 10)) << (14 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (14 - 10); out++; - *out = ((*in) >> 10) % (1U << 14); + *out = (inl >> 10) % (1U << 14); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - 
*out |= ((*in) % (1U << 6)) << (14 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (14 - 6); out++; - *out = ((*in) >> 6) % (1U << 14); + *out = (inl >> 6) % (1U << 14); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 2)) << (14 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (14 - 2); out++; - *out = ((*in) >> 2) % (1U << 14); + *out = (inl >> 2) % (1U << 14); out++; - *out = ((*in) >> 16) % (1U << 14); + *out = (inl >> 16) % (1U << 14); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 12)) << (14 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (14 - 12); out++; - *out = ((*in) >> 12) % (1U << 14); + *out = (inl >> 12) % (1U << 14); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 8)) << (14 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (14 - 8); out++; - *out = ((*in) >> 8) % (1U << 14); + *out = (inl >> 8) % (1U << 14); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 4)) << (14 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (14 - 4); out++; - *out = ((*in) >> 4) % (1U << 14); + *out = (inl >> 4) % (1U << 14); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; out++; @@ -1178,97 +1284,112 @@ inline const uint32_t* unpack14_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack15_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 15); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 15); out++; - *out = ((*in) >> 15) % (1U << 15); + *out = (inl >> 15) % (1U << 15); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 13)) << (15 - 13); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (15 - 13); out++; - *out = ((*in) >> 13) % (1U << 15); + *out = (inl >> 13) % (1U << 15); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; 
- *out |= ((*in) % (1U << 11)) << (15 - 11); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (15 - 11); out++; - *out = ((*in) >> 11) % (1U << 15); + *out = (inl >> 11) % (1U << 15); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 9)) << (15 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (15 - 9); out++; - *out = ((*in) >> 9) % (1U << 15); + *out = (inl >> 9) % (1U << 15); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 7)) << (15 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (15 - 7); out++; - *out = ((*in) >> 7) % (1U << 15); + *out = (inl >> 7) % (1U << 15); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 5)) << (15 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (15 - 5); out++; - *out = ((*in) >> 5) % (1U << 15); + *out = (inl >> 5) % (1U << 15); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 3)) << (15 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (15 - 3); out++; - *out = ((*in) >> 3) % (1U << 15); + *out = (inl >> 3) % (1U << 15); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 1)) << (15 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (15 - 1); out++; - *out = ((*in) >> 1) % (1U << 15); + *out = (inl >> 1) % (1U << 15); out++; - *out = ((*in) >> 16) % (1U << 15); + *out = (inl >> 16) % (1U << 15); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 14)) << (15 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (15 - 14); out++; - *out = ((*in) >> 14) % (1U << 15); + *out = (inl >> 14) % (1U << 15); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 12)) << (15 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (15 - 12); out++; - *out = ((*in) >> 12) % (1U << 15); + *out = 
(inl >> 12) % (1U << 15); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 10)) << (15 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (15 - 10); out++; - *out = ((*in) >> 10) % (1U << 15); + *out = (inl >> 10) % (1U << 15); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 8)) << (15 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (15 - 8); out++; - *out = ((*in) >> 8) % (1U << 15); + *out = (inl >> 8) % (1U << 15); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 6)) << (15 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (15 - 6); out++; - *out = ((*in) >> 6) % (1U << 15); + *out = (inl >> 6) % (1U << 15); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 4)) << (15 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (15 - 4); out++; - *out = ((*in) >> 4) % (1U << 15); + *out = (inl >> 4) % (1U << 15); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 2)) << (15 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (15 - 2); out++; - *out = ((*in) >> 2) % (1U << 15); + *out = (inl >> 2) % (1U << 15); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; out++; @@ -1276,84 +1397,100 @@ inline const uint32_t* unpack15_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack16_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 16); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); 
++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; + inl = 
util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 16); + *out = (inl >> 0) % (1U << 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; out++; @@ -1361,101 +1498,118 @@ inline const uint32_t* unpack16_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack17_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 17); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 17); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; - *out |= ((*in) % (1U << 2)) << (17 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (17 - 2); out++; - *out = ((*in) >> 2) % (1U << 17); + *out = (inl >> 2) % (1U << 17); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 4)) << (17 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (17 - 4); out++; - *out = ((*in) >> 4) % (1U << 17); + *out = (inl >> 4) % (1U << 17); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 6)) << (17 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (17 - 6); out++; - *out = ((*in) >> 6) % (1U << 17); + *out = (inl >> 6) % (1U << 17); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 8)) << (17 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (17 - 8); out++; - *out = ((*in) >> 8) % (1U << 17); + *out = (inl >> 8) % (1U << 17); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 10)) << (17 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (17 - 10); out++; - *out = ((*in) >> 10) % (1U << 17); + *out = (inl >> 10) % (1U << 17); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 12)) << (17 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (17 - 12); out++; - *out = ((*in) >> 12) % (1U << 17); + *out = (inl >> 12) % (1U << 17); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); 
++in; - *out |= ((*in) % (1U << 14)) << (17 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (17 - 14); out++; - *out = ((*in) >> 14) % (1U << 17); + *out = (inl >> 14) % (1U << 17); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 16)) << (17 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (17 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 1)) << (17 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (17 - 1); out++; - *out = ((*in) >> 1) % (1U << 17); + *out = (inl >> 1) % (1U << 17); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 3)) << (17 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (17 - 3); out++; - *out = ((*in) >> 3) % (1U << 17); + *out = (inl >> 3) % (1U << 17); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 5)) << (17 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (17 - 5); out++; - *out = ((*in) >> 5) % (1U << 17); + *out = (inl >> 5) % (1U << 17); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 7)) << (17 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (17 - 7); out++; - *out = ((*in) >> 7) % (1U << 17); + *out = (inl >> 7) % (1U << 17); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 9)) << (17 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (17 - 9); out++; - *out = ((*in) >> 9) % (1U << 17); + *out = (inl >> 9) % (1U << 17); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 11)) << (17 - 11); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (17 - 11); out++; - *out = ((*in) >> 11) % (1U << 17); + *out = (inl >> 11) % (1U << 17); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 13)) << (17 - 13); + inl = 
util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (17 - 13); out++; - *out = ((*in) >> 13) % (1U << 17); + *out = (inl >> 13) % (1U << 17); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 15)) << (17 - 15); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 15)) << (17 - 15); out++; - *out = ((*in) >> 15); + *out = (inl >> 15); ++in; out++; @@ -1463,102 +1617,120 @@ inline const uint32_t* unpack17_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack18_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 18); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 4)) << (18 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (18 - 4); out++; - *out = ((*in) >> 4) % (1U << 18); + *out = (inl >> 4) % (1U << 18); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 8)) << (18 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (18 - 8); out++; - *out = ((*in) >> 8) % (1U << 18); + *out = (inl >> 8) % (1U << 18); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 12)) << (18 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (18 - 12); out++; - *out = ((*in) >> 12) % (1U << 18); + *out = (inl >> 12) % (1U << 18); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 16)) << (18 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (18 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 2)) << (18 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (18 - 2); out++; - *out = ((*in) >> 2) % (1U << 18); + *out = (inl >> 2) % (1U << 18); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 6)) << (18 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (18 
- 6); out++; - *out = ((*in) >> 6) % (1U << 18); + *out = (inl >> 6) % (1U << 18); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 10)) << (18 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (18 - 10); out++; - *out = ((*in) >> 10) % (1U << 18); + *out = (inl >> 10) % (1U << 18); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 14)) << (18 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (18 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 18); + *out = (inl >> 0) % (1U << 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 4)) << (18 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (18 - 4); out++; - *out = ((*in) >> 4) % (1U << 18); + *out = (inl >> 4) % (1U << 18); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 8)) << (18 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (18 - 8); out++; - *out = ((*in) >> 8) % (1U << 18); + *out = (inl >> 8) % (1U << 18); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 12)) << (18 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (18 - 12); out++; - *out = ((*in) >> 12) % (1U << 18); + *out = (inl >> 12) % (1U << 18); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 16)) << (18 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (18 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 2)) << (18 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (18 - 2); out++; - *out = ((*in) >> 2) % (1U << 18); + *out = (inl >> 2) % (1U << 18); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 6)) << (18 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U 
<< 6)) << (18 - 6); out++; - *out = ((*in) >> 6) % (1U << 18); + *out = (inl >> 6) % (1U << 18); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 10)) << (18 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (18 - 10); out++; - *out = ((*in) >> 10) % (1U << 18); + *out = (inl >> 10) % (1U << 18); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 14)) << (18 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (18 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; out++; @@ -1566,105 +1738,124 @@ inline const uint32_t* unpack18_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack19_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 19); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 19); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 6)) << (19 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (19 - 6); out++; - *out = ((*in) >> 6) % (1U << 19); + *out = (inl >> 6) % (1U << 19); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 12)) << (19 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (19 - 12); out++; - *out = ((*in) >> 12) % (1U << 19); + *out = (inl >> 12) % (1U << 19); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 18)) << (19 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (19 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 5)) << (19 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (19 - 5); out++; - *out = ((*in) >> 5) % (1U << 19); + *out = (inl >> 5) % (1U << 19); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 11)) << (19 - 11); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (19 - 11); out++; - *out = ((*in) >> 
11) % (1U << 19); + *out = (inl >> 11) % (1U << 19); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 17)) << (19 - 17); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 17)) << (19 - 17); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; - *out |= ((*in) % (1U << 4)) << (19 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (19 - 4); out++; - *out = ((*in) >> 4) % (1U << 19); + *out = (inl >> 4) % (1U << 19); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 10)) << (19 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (19 - 10); out++; - *out = ((*in) >> 10) % (1U << 19); + *out = (inl >> 10) % (1U << 19); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 16)) << (19 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (19 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 3)) << (19 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (19 - 3); out++; - *out = ((*in) >> 3) % (1U << 19); + *out = (inl >> 3) % (1U << 19); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 9)) << (19 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (19 - 9); out++; - *out = ((*in) >> 9) % (1U << 19); + *out = (inl >> 9) % (1U << 19); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 15)) << (19 - 15); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 15)) << (19 - 15); out++; - *out = ((*in) >> 15); + *out = (inl >> 15); ++in; - *out |= ((*in) % (1U << 2)) << (19 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (19 - 2); out++; - *out = ((*in) >> 2) % (1U << 19); + *out = (inl >> 2) % (1U << 19); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 8)) << (19 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (19 - 8); 
out++; - *out = ((*in) >> 8) % (1U << 19); + *out = (inl >> 8) % (1U << 19); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 14)) << (19 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (19 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 1)) << (19 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (19 - 1); out++; - *out = ((*in) >> 1) % (1U << 19); + *out = (inl >> 1) % (1U << 19); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 7)) << (19 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (19 - 7); out++; - *out = ((*in) >> 7) % (1U << 19); + *out = (inl >> 7) % (1U << 19); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 13)) << (19 - 13); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (19 - 13); out++; - *out = ((*in) >> 13); + *out = (inl >> 13); ++in; out++; @@ -1672,104 +1863,124 @@ inline const uint32_t* unpack19_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack20_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 20); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 8)) << (20 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (20 - 8); out++; - *out = ((*in) >> 8) % (1U << 20); + *out = (inl >> 8) % (1U << 20); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 16)) << (20 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (20 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 4)) << (20 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (20 - 4); out++; - *out = ((*in) >> 4) % (1U << 20); + *out = (inl >> 4) % (1U << 20); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - 
*out |= ((*in) % (1U << 12)) << (20 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (20 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 20); + *out = (inl >> 0) % (1U << 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 8)) << (20 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (20 - 8); out++; - *out = ((*in) >> 8) % (1U << 20); + *out = (inl >> 8) % (1U << 20); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 16)) << (20 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (20 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 4)) << (20 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (20 - 4); out++; - *out = ((*in) >> 4) % (1U << 20); + *out = (inl >> 4) % (1U << 20); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 12)) << (20 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (20 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 20); + *out = (inl >> 0) % (1U << 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 8)) << (20 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (20 - 8); out++; - *out = ((*in) >> 8) % (1U << 20); + *out = (inl >> 8) % (1U << 20); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 16)) << (20 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (20 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 4)) << (20 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (20 - 4); out++; - *out = ((*in) >> 4) % (1U << 20); + *out = (inl >> 4) % (1U << 20); out++; - *out = ((*in) >> 24); + *out = 
(inl >> 24); ++in; - *out |= ((*in) % (1U << 12)) << (20 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (20 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 20); + *out = (inl >> 0) % (1U << 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 8)) << (20 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (20 - 8); out++; - *out = ((*in) >> 8) % (1U << 20); + *out = (inl >> 8) % (1U << 20); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 16)) << (20 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (20 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 4)) << (20 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (20 - 4); out++; - *out = ((*in) >> 4) % (1U << 20); + *out = (inl >> 4) % (1U << 20); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 12)) << (20 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (20 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; out++; @@ -1777,109 +1988,130 @@ inline const uint32_t* unpack20_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack21_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 21); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 21); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 10)) << (21 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (21 - 10); out++; - *out = ((*in) >> 10) % (1U << 21); + *out = (inl >> 10) % (1U << 21); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 20)) << (21 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (21 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % 
(1U << 9)) << (21 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (21 - 9); out++; - *out = ((*in) >> 9) % (1U << 21); + *out = (inl >> 9) % (1U << 21); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 19)) << (21 - 19); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 19)) << (21 - 19); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 8)) << (21 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (21 - 8); out++; - *out = ((*in) >> 8) % (1U << 21); + *out = (inl >> 8) % (1U << 21); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 18)) << (21 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (21 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 7)) << (21 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (21 - 7); out++; - *out = ((*in) >> 7) % (1U << 21); + *out = (inl >> 7) % (1U << 21); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 17)) << (21 - 17); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 17)) << (21 - 17); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; - *out |= ((*in) % (1U << 6)) << (21 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (21 - 6); out++; - *out = ((*in) >> 6) % (1U << 21); + *out = (inl >> 6) % (1U << 21); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 16)) << (21 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (21 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 5)) << (21 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (21 - 5); out++; - *out = ((*in) >> 5) % (1U << 21); + *out = (inl >> 5) % (1U << 21); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 15)) << (21 - 15); + inl = util::SafeLoad(in); + 
*out |= (inl % (1U << 15)) << (21 - 15); out++; - *out = ((*in) >> 15); + *out = (inl >> 15); ++in; - *out |= ((*in) % (1U << 4)) << (21 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (21 - 4); out++; - *out = ((*in) >> 4) % (1U << 21); + *out = (inl >> 4) % (1U << 21); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 14)) << (21 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (21 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 3)) << (21 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (21 - 3); out++; - *out = ((*in) >> 3) % (1U << 21); + *out = (inl >> 3) % (1U << 21); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 13)) << (21 - 13); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (21 - 13); out++; - *out = ((*in) >> 13); + *out = (inl >> 13); ++in; - *out |= ((*in) % (1U << 2)) << (21 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (21 - 2); out++; - *out = ((*in) >> 2) % (1U << 21); + *out = (inl >> 2) % (1U << 21); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 12)) << (21 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (21 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 1)) << (21 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (21 - 1); out++; - *out = ((*in) >> 1) % (1U << 21); + *out = (inl >> 1) % (1U << 21); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 11)) << (21 - 11); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (21 - 11); out++; - *out = ((*in) >> 11); + *out = (inl >> 11); ++in; out++; @@ -1887,110 +2119,132 @@ inline const uint32_t* unpack21_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack22_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U 
<< 22); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 12)) << (22 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (22 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 2)) << (22 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (22 - 2); out++; - *out = ((*in) >> 2) % (1U << 22); + *out = (inl >> 2) % (1U << 22); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 14)) << (22 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (22 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 4)) << (22 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (22 - 4); out++; - *out = ((*in) >> 4) % (1U << 22); + *out = (inl >> 4) % (1U << 22); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 16)) << (22 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (22 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 6)) << (22 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (22 - 6); out++; - *out = ((*in) >> 6) % (1U << 22); + *out = (inl >> 6) % (1U << 22); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 18)) << (22 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (22 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 8)) << (22 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (22 - 8); out++; - *out = ((*in) >> 8) % (1U << 22); + *out = (inl >> 8) % (1U << 22); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 20)) << (22 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (22 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); 
++in; - *out |= ((*in) % (1U << 10)) << (22 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (22 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 22); + *out = (inl >> 0) % (1U << 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 12)) << (22 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (22 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 2)) << (22 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (22 - 2); out++; - *out = ((*in) >> 2) % (1U << 22); + *out = (inl >> 2) % (1U << 22); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 14)) << (22 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (22 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 4)) << (22 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (22 - 4); out++; - *out = ((*in) >> 4) % (1U << 22); + *out = (inl >> 4) % (1U << 22); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 16)) << (22 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (22 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 6)) << (22 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (22 - 6); out++; - *out = ((*in) >> 6) % (1U << 22); + *out = (inl >> 6) % (1U << 22); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 18)) << (22 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (22 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 8)) << (22 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (22 - 8); out++; - *out = ((*in) >> 8) % (1U << 22); + *out = (inl >> 8) % (1U << 22); out++; - *out = 
((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 20)) << (22 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (22 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 10)) << (22 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (22 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; out++; @@ -1998,113 +2252,136 @@ inline const uint32_t* unpack22_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack23_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 23); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 23); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 14)) << (23 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (23 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 5)) << (23 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (23 - 5); out++; - *out = ((*in) >> 5) % (1U << 23); + *out = (inl >> 5) % (1U << 23); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 19)) << (23 - 19); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 19)) << (23 - 19); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 10)) << (23 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (23 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 1)) << (23 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (23 - 1); out++; - *out = ((*in) >> 1) % (1U << 23); + *out = (inl >> 1) % (1U << 23); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 15)) << (23 - 15); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 15)) << (23 - 15); out++; - *out = ((*in) >> 15); + *out = (inl >> 15); ++in; - *out |= ((*in) % (1U << 6)) << (23 - 6); + inl = 
util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (23 - 6); out++; - *out = ((*in) >> 6) % (1U << 23); + *out = (inl >> 6) % (1U << 23); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 20)) << (23 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (23 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 11)) << (23 - 11); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (23 - 11); out++; - *out = ((*in) >> 11); + *out = (inl >> 11); ++in; - *out |= ((*in) % (1U << 2)) << (23 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (23 - 2); out++; - *out = ((*in) >> 2) % (1U << 23); + *out = (inl >> 2) % (1U << 23); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 16)) << (23 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (23 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 7)) << (23 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (23 - 7); out++; - *out = ((*in) >> 7) % (1U << 23); + *out = (inl >> 7) % (1U << 23); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 21)) << (23 - 21); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 21)) << (23 - 21); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 12)) << (23 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (23 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 3)) << (23 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (23 - 3); out++; - *out = ((*in) >> 3) % (1U << 23); + *out = (inl >> 3) % (1U << 23); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 17)) << (23 - 17); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 17)) << (23 - 17); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; - *out 
|= ((*in) % (1U << 8)) << (23 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (23 - 8); out++; - *out = ((*in) >> 8) % (1U << 23); + *out = (inl >> 8) % (1U << 23); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 22)) << (23 - 22); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 22)) << (23 - 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 13)) << (23 - 13); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (23 - 13); out++; - *out = ((*in) >> 13); + *out = (inl >> 13); ++in; - *out |= ((*in) % (1U << 4)) << (23 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (23 - 4); out++; - *out = ((*in) >> 4) % (1U << 23); + *out = (inl >> 4) % (1U << 23); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 18)) << (23 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (23 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 9)) << (23 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (23 - 9); out++; - *out = ((*in) >> 9); + *out = (inl >> 9); ++in; out++; @@ -2112,108 +2389,132 @@ inline const uint32_t* unpack23_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack24_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 24); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 16)) << (24 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (24 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 8)) << (24 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (24 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 24); + *out = (inl >> 0) % (1U << 24); out++; - *out = ((*in) 
>> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 16)) << (24 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (24 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 8)) << (24 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (24 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 24); + *out = (inl >> 0) % (1U << 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 16)) << (24 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (24 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 8)) << (24 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (24 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 24); + *out = (inl >> 0) % (1U << 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 16)) << (24 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (24 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 8)) << (24 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (24 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 24); + *out = (inl >> 0) % (1U << 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 16)) << (24 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (24 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 8)) << (24 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (24 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 24); + *out = (inl >> 0) % 
(1U << 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 16)) << (24 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (24 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 8)) << (24 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (24 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 24); + *out = (inl >> 0) % (1U << 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 16)) << (24 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (24 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 8)) << (24 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (24 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 24); + *out = (inl >> 0) % (1U << 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 16)) << (24 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (24 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 8)) << (24 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (24 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; out++; @@ -2221,117 +2522,142 @@ inline const uint32_t* unpack24_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack25_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 25); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 25); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 18)) << (25 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (25 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 11)) << (25 - 11); + 
inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (25 - 11); out++; - *out = ((*in) >> 11); + *out = (inl >> 11); ++in; - *out |= ((*in) % (1U << 4)) << (25 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (25 - 4); out++; - *out = ((*in) >> 4) % (1U << 25); + *out = (inl >> 4) % (1U << 25); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 22)) << (25 - 22); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 22)) << (25 - 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 15)) << (25 - 15); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 15)) << (25 - 15); out++; - *out = ((*in) >> 15); + *out = (inl >> 15); ++in; - *out |= ((*in) % (1U << 8)) << (25 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (25 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 1)) << (25 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (25 - 1); out++; - *out = ((*in) >> 1) % (1U << 25); + *out = (inl >> 1) % (1U << 25); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 19)) << (25 - 19); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 19)) << (25 - 19); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 12)) << (25 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (25 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 5)) << (25 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (25 - 5); out++; - *out = ((*in) >> 5) % (1U << 25); + *out = (inl >> 5) % (1U << 25); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 23)) << (25 - 23); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 23)) << (25 - 23); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 16)) << (25 - 16); + inl = util::SafeLoad(in); + *out |= 
(inl % (1U << 16)) << (25 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 9)) << (25 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (25 - 9); out++; - *out = ((*in) >> 9); + *out = (inl >> 9); ++in; - *out |= ((*in) % (1U << 2)) << (25 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (25 - 2); out++; - *out = ((*in) >> 2) % (1U << 25); + *out = (inl >> 2) % (1U << 25); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 20)) << (25 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (25 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 13)) << (25 - 13); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (25 - 13); out++; - *out = ((*in) >> 13); + *out = (inl >> 13); ++in; - *out |= ((*in) % (1U << 6)) << (25 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (25 - 6); out++; - *out = ((*in) >> 6) % (1U << 25); + *out = (inl >> 6) % (1U << 25); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 24)) << (25 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (25 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 17)) << (25 - 17); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 17)) << (25 - 17); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; - *out |= ((*in) % (1U << 10)) << (25 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (25 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 3)) << (25 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (25 - 3); out++; - *out = ((*in) >> 3) % (1U << 25); + *out = (inl >> 3) % (1U << 25); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 21)) << (25 - 21); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 21)) << (25 - 21); 
out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 14)) << (25 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (25 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 7)) << (25 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (25 - 7); out++; - *out = ((*in) >> 7); + *out = (inl >> 7); ++in; out++; @@ -2339,118 +2665,144 @@ inline const uint32_t* unpack25_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack26_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 26); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 26); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 20)) << (26 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (26 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 14)) << (26 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (26 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 8)) << (26 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (26 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 2)) << (26 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (26 - 2); out++; - *out = ((*in) >> 2) % (1U << 26); + *out = (inl >> 2) % (1U << 26); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 22)) << (26 - 22); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 22)) << (26 - 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 16)) << (26 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (26 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 10)) << (26 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (26 - 10); out++; - *out = 
((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 4)) << (26 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (26 - 4); out++; - *out = ((*in) >> 4) % (1U << 26); + *out = (inl >> 4) % (1U << 26); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 24)) << (26 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (26 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 18)) << (26 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (26 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 12)) << (26 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (26 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 6)) << (26 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (26 - 6); out++; - *out = ((*in) >> 6); + *out = (inl >> 6); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 26); + *out = (inl >> 0) % (1U << 26); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 20)) << (26 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (26 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 14)) << (26 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (26 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 8)) << (26 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (26 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 2)) << (26 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (26 - 2); out++; - *out = ((*in) >> 2) % (1U << 26); + *out = (inl >> 2) % (1U << 26); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 22)) << (26 - 22); + inl = util::SafeLoad(in); + *out |= 
(inl % (1U << 22)) << (26 - 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 16)) << (26 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (26 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 10)) << (26 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (26 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 4)) << (26 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (26 - 4); out++; - *out = ((*in) >> 4) % (1U << 26); + *out = (inl >> 4) % (1U << 26); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 24)) << (26 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (26 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 18)) << (26 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (26 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 12)) << (26 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (26 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 6)) << (26 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (26 - 6); out++; - *out = ((*in) >> 6); + *out = (inl >> 6); ++in; out++; @@ -2458,121 +2810,148 @@ inline const uint32_t* unpack26_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack27_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 27); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 27); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 22)) << (27 - 22); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 22)) << (27 - 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 17)) << (27 - 17); + inl = util::SafeLoad(in); + *out |= (inl % (1U 
<< 17)) << (27 - 17); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; - *out |= ((*in) % (1U << 12)) << (27 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (27 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 7)) << (27 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (27 - 7); out++; - *out = ((*in) >> 7); + *out = (inl >> 7); ++in; - *out |= ((*in) % (1U << 2)) << (27 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (27 - 2); out++; - *out = ((*in) >> 2) % (1U << 27); + *out = (inl >> 2) % (1U << 27); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 24)) << (27 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (27 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 19)) << (27 - 19); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 19)) << (27 - 19); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 14)) << (27 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (27 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 9)) << (27 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (27 - 9); out++; - *out = ((*in) >> 9); + *out = (inl >> 9); ++in; - *out |= ((*in) % (1U << 4)) << (27 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (27 - 4); out++; - *out = ((*in) >> 4) % (1U << 27); + *out = (inl >> 4) % (1U << 27); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 26)) << (27 - 26); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 26)) << (27 - 26); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 21)) << (27 - 21); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 21)) << (27 - 21); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 16)) << 
(27 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (27 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 11)) << (27 - 11); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (27 - 11); out++; - *out = ((*in) >> 11); + *out = (inl >> 11); ++in; - *out |= ((*in) % (1U << 6)) << (27 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (27 - 6); out++; - *out = ((*in) >> 6); + *out = (inl >> 6); ++in; - *out |= ((*in) % (1U << 1)) << (27 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (27 - 1); out++; - *out = ((*in) >> 1) % (1U << 27); + *out = (inl >> 1) % (1U << 27); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 23)) << (27 - 23); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 23)) << (27 - 23); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 18)) << (27 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (27 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 13)) << (27 - 13); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (27 - 13); out++; - *out = ((*in) >> 13); + *out = (inl >> 13); ++in; - *out |= ((*in) % (1U << 8)) << (27 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (27 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 3)) << (27 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (27 - 3); out++; - *out = ((*in) >> 3) % (1U << 27); + *out = (inl >> 3) % (1U << 27); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 25)) << (27 - 25); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 25)) << (27 - 25); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 20)) << (27 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (27 - 20); out++; - *out = ((*in) >> 20); + 
*out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 15)) << (27 - 15); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 15)) << (27 - 15); out++; - *out = ((*in) >> 15); + *out = (inl >> 15); ++in; - *out |= ((*in) % (1U << 10)) << (27 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (27 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 5)) << (27 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (27 - 5); out++; - *out = ((*in) >> 5); + *out = (inl >> 5); ++in; out++; @@ -2580,120 +2959,148 @@ inline const uint32_t* unpack27_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack28_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 28); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 28); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 24)) << (28 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (28 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 20)) << (28 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (28 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 16)) << (28 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (28 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 12)) << (28 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (28 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 8)) << (28 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (28 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 4)) << (28 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (28 - 4); out++; - *out = ((*in) >> 4); + *out = (inl >> 4); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 28); 
+ *out = (inl >> 0) % (1U << 28); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 24)) << (28 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (28 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 20)) << (28 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (28 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 16)) << (28 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (28 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 12)) << (28 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (28 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 8)) << (28 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (28 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 4)) << (28 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (28 - 4); out++; - *out = ((*in) >> 4); + *out = (inl >> 4); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 28); + *out = (inl >> 0) % (1U << 28); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 24)) << (28 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (28 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 20)) << (28 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (28 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 16)) << (28 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (28 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 12)) << (28 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (28 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - 
*out |= ((*in) % (1U << 8)) << (28 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (28 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 4)) << (28 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (28 - 4); out++; - *out = ((*in) >> 4); + *out = (inl >> 4); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 28); + *out = (inl >> 0) % (1U << 28); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 24)) << (28 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (28 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 20)) << (28 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (28 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 16)) << (28 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (28 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 12)) << (28 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (28 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 8)) << (28 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (28 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 4)) << (28 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (28 - 4); out++; - *out = ((*in) >> 4); + *out = (inl >> 4); ++in; out++; @@ -2701,125 +3108,154 @@ inline const uint32_t* unpack28_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack29_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 29); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 29); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 26)) << (29 - 26); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 26)) << (29 - 26); 
out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 23)) << (29 - 23); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 23)) << (29 - 23); out++; - *out = ((*in) >> 23); + *out = (inl >> 23); ++in; - *out |= ((*in) % (1U << 20)) << (29 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (29 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 17)) << (29 - 17); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 17)) << (29 - 17); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; - *out |= ((*in) % (1U << 14)) << (29 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (29 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 11)) << (29 - 11); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (29 - 11); out++; - *out = ((*in) >> 11); + *out = (inl >> 11); ++in; - *out |= ((*in) % (1U << 8)) << (29 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (29 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 5)) << (29 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (29 - 5); out++; - *out = ((*in) >> 5); + *out = (inl >> 5); ++in; - *out |= ((*in) % (1U << 2)) << (29 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (29 - 2); out++; - *out = ((*in) >> 2) % (1U << 29); + *out = (inl >> 2) % (1U << 29); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 28)) << (29 - 28); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 28)) << (29 - 28); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 25)) << (29 - 25); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 25)) << (29 - 25); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 22)) << (29 - 22); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 22)) << (29 - 22); out++; - *out 
= ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 19)) << (29 - 19); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 19)) << (29 - 19); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 16)) << (29 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (29 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 13)) << (29 - 13); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (29 - 13); out++; - *out = ((*in) >> 13); + *out = (inl >> 13); ++in; - *out |= ((*in) % (1U << 10)) << (29 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (29 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 7)) << (29 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (29 - 7); out++; - *out = ((*in) >> 7); + *out = (inl >> 7); ++in; - *out |= ((*in) % (1U << 4)) << (29 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (29 - 4); out++; - *out = ((*in) >> 4); + *out = (inl >> 4); ++in; - *out |= ((*in) % (1U << 1)) << (29 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (29 - 1); out++; - *out = ((*in) >> 1) % (1U << 29); + *out = (inl >> 1) % (1U << 29); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 27)) << (29 - 27); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 27)) << (29 - 27); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 24)) << (29 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (29 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 21)) << (29 - 21); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 21)) << (29 - 21); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 18)) << (29 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (29 - 18); out++; - *out = ((*in) >> 
18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 15)) << (29 - 15); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 15)) << (29 - 15); out++; - *out = ((*in) >> 15); + *out = (inl >> 15); ++in; - *out |= ((*in) % (1U << 12)) << (29 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (29 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 9)) << (29 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (29 - 9); out++; - *out = ((*in) >> 9); + *out = (inl >> 9); ++in; - *out |= ((*in) % (1U << 6)) << (29 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (29 - 6); out++; - *out = ((*in) >> 6); + *out = (inl >> 6); ++in; - *out |= ((*in) % (1U << 3)) << (29 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (29 - 3); out++; - *out = ((*in) >> 3); + *out = (inl >> 3); ++in; out++; @@ -2827,126 +3263,156 @@ inline const uint32_t* unpack29_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack30_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 30); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 30); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 28)) << (30 - 28); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 28)) << (30 - 28); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 26)) << (30 - 26); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 26)) << (30 - 26); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 24)) << (30 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (30 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 22)) << (30 - 22); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 22)) << (30 - 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 20)) << (30 - 20); + inl = 
util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (30 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 18)) << (30 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (30 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 16)) << (30 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (30 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 14)) << (30 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (30 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 12)) << (30 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (30 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 10)) << (30 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (30 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 8)) << (30 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (30 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 6)) << (30 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (30 - 6); out++; - *out = ((*in) >> 6); + *out = (inl >> 6); ++in; - *out |= ((*in) % (1U << 4)) << (30 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (30 - 4); out++; - *out = ((*in) >> 4); + *out = (inl >> 4); ++in; - *out |= ((*in) % (1U << 2)) << (30 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (30 - 2); out++; - *out = ((*in) >> 2); + *out = (inl >> 2); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0) % (1U << 30); + *out = (inl >> 0) % (1U << 30); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 28)) << (30 - 28); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 28)) << (30 - 28); out++; - *out = ((*in) >> 28); + *out = (inl 
>> 28); ++in; - *out |= ((*in) % (1U << 26)) << (30 - 26); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 26)) << (30 - 26); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 24)) << (30 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (30 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 22)) << (30 - 22); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 22)) << (30 - 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 20)) << (30 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (30 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 18)) << (30 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (30 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 16)) << (30 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (30 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 14)) << (30 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (30 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 12)) << (30 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (30 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 10)) << (30 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (30 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 8)) << (30 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (30 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 6)) << (30 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (30 - 6); out++; - *out = ((*in) >> 6); + *out = (inl >> 6); ++in; - *out |= ((*in) % (1U << 4)) << (30 - 4); + inl = util::SafeLoad(in); + 
*out |= (inl % (1U << 4)) << (30 - 4); out++; - *out = ((*in) >> 4); + *out = (inl >> 4); ++in; - *out |= ((*in) % (1U << 2)) << (30 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (30 - 2); out++; - *out = ((*in) >> 2); + *out = (inl >> 2); ++in; out++; @@ -2954,129 +3420,160 @@ inline const uint32_t* unpack30_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack31_32(const uint32_t* in, uint32_t* out) { - *out = ((*in) >> 0) % (1U << 31); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0) % (1U << 31); out++; - *out = ((*in) >> 31); + *out = (inl >> 31); ++in; - *out |= ((*in) % (1U << 30)) << (31 - 30); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 30)) << (31 - 30); out++; - *out = ((*in) >> 30); + *out = (inl >> 30); ++in; - *out |= ((*in) % (1U << 29)) << (31 - 29); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 29)) << (31 - 29); out++; - *out = ((*in) >> 29); + *out = (inl >> 29); ++in; - *out |= ((*in) % (1U << 28)) << (31 - 28); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 28)) << (31 - 28); out++; - *out = ((*in) >> 28); + *out = (inl >> 28); ++in; - *out |= ((*in) % (1U << 27)) << (31 - 27); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 27)) << (31 - 27); out++; - *out = ((*in) >> 27); + *out = (inl >> 27); ++in; - *out |= ((*in) % (1U << 26)) << (31 - 26); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 26)) << (31 - 26); out++; - *out = ((*in) >> 26); + *out = (inl >> 26); ++in; - *out |= ((*in) % (1U << 25)) << (31 - 25); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 25)) << (31 - 25); out++; - *out = ((*in) >> 25); + *out = (inl >> 25); ++in; - *out |= ((*in) % (1U << 24)) << (31 - 24); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 24)) << (31 - 24); out++; - *out = ((*in) >> 24); + *out = (inl >> 24); ++in; - *out |= ((*in) % (1U << 23)) << (31 - 23); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 23)) << (31 - 23); out++; - *out = ((*in) >> 23); + *out = (inl >> 
23); ++in; - *out |= ((*in) % (1U << 22)) << (31 - 22); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 22)) << (31 - 22); out++; - *out = ((*in) >> 22); + *out = (inl >> 22); ++in; - *out |= ((*in) % (1U << 21)) << (31 - 21); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 21)) << (31 - 21); out++; - *out = ((*in) >> 21); + *out = (inl >> 21); ++in; - *out |= ((*in) % (1U << 20)) << (31 - 20); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 20)) << (31 - 20); out++; - *out = ((*in) >> 20); + *out = (inl >> 20); ++in; - *out |= ((*in) % (1U << 19)) << (31 - 19); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 19)) << (31 - 19); out++; - *out = ((*in) >> 19); + *out = (inl >> 19); ++in; - *out |= ((*in) % (1U << 18)) << (31 - 18); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 18)) << (31 - 18); out++; - *out = ((*in) >> 18); + *out = (inl >> 18); ++in; - *out |= ((*in) % (1U << 17)) << (31 - 17); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 17)) << (31 - 17); out++; - *out = ((*in) >> 17); + *out = (inl >> 17); ++in; - *out |= ((*in) % (1U << 16)) << (31 - 16); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 16)) << (31 - 16); out++; - *out = ((*in) >> 16); + *out = (inl >> 16); ++in; - *out |= ((*in) % (1U << 15)) << (31 - 15); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 15)) << (31 - 15); out++; - *out = ((*in) >> 15); + *out = (inl >> 15); ++in; - *out |= ((*in) % (1U << 14)) << (31 - 14); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 14)) << (31 - 14); out++; - *out = ((*in) >> 14); + *out = (inl >> 14); ++in; - *out |= ((*in) % (1U << 13)) << (31 - 13); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 13)) << (31 - 13); out++; - *out = ((*in) >> 13); + *out = (inl >> 13); ++in; - *out |= ((*in) % (1U << 12)) << (31 - 12); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 12)) << (31 - 12); out++; - *out = ((*in) >> 12); + *out = (inl >> 12); ++in; - *out |= ((*in) % (1U << 11)) << (31 - 11); + inl = 
util::SafeLoad(in); + *out |= (inl % (1U << 11)) << (31 - 11); out++; - *out = ((*in) >> 11); + *out = (inl >> 11); ++in; - *out |= ((*in) % (1U << 10)) << (31 - 10); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 10)) << (31 - 10); out++; - *out = ((*in) >> 10); + *out = (inl >> 10); ++in; - *out |= ((*in) % (1U << 9)) << (31 - 9); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 9)) << (31 - 9); out++; - *out = ((*in) >> 9); + *out = (inl >> 9); ++in; - *out |= ((*in) % (1U << 8)) << (31 - 8); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 8)) << (31 - 8); out++; - *out = ((*in) >> 8); + *out = (inl >> 8); ++in; - *out |= ((*in) % (1U << 7)) << (31 - 7); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 7)) << (31 - 7); out++; - *out = ((*in) >> 7); + *out = (inl >> 7); ++in; - *out |= ((*in) % (1U << 6)) << (31 - 6); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 6)) << (31 - 6); out++; - *out = ((*in) >> 6); + *out = (inl >> 6); ++in; - *out |= ((*in) % (1U << 5)) << (31 - 5); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 5)) << (31 - 5); out++; - *out = ((*in) >> 5); + *out = (inl >> 5); ++in; - *out |= ((*in) % (1U << 4)) << (31 - 4); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 4)) << (31 - 4); out++; - *out = ((*in) >> 4); + *out = (inl >> 4); ++in; - *out |= ((*in) % (1U << 3)) << (31 - 3); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 3)) << (31 - 3); out++; - *out = ((*in) >> 3); + *out = (inl >> 3); ++in; - *out |= ((*in) % (1U << 2)) << (31 - 2); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 2)) << (31 - 2); out++; - *out = ((*in) >> 2); + *out = (inl >> 2); ++in; - *out |= ((*in) % (1U << 1)) << (31 - 1); + inl = util::SafeLoad(in); + *out |= (inl % (1U << 1)) << (31 - 1); out++; - *out = ((*in) >> 1); + *out = (inl >> 1); ++in; out++; @@ -3084,100 +3581,132 @@ inline const uint32_t* unpack31_32(const uint32_t* in, uint32_t* out) { } inline const uint32_t* unpack32_32(const uint32_t* in, uint32_t* out) { - 
*out = ((*in) >> 0); + uint32_t inl = util::SafeLoad(in); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = 
((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; + inl = util::SafeLoad(in); out++; - *out = ((*in) >> 0); + *out = (inl >> 0); ++in; out++; diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index 49641d81c08..c053da8ae15 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -149,9 +149,8 @@ hash_t ComputeStringHash(const void* data, int64_t length) { // the results uint32_t x, y; hash_t hx, hy; - // XXX those are unaligned accesses. Should we have a facility for that? 
- x = *reinterpret_cast(p + n - 4); - y = *reinterpret_cast(p); + x = util::SafeLoadAs(p + n - 4); + y = util::SafeLoadAs(p); hx = ScalarHelper::ComputeHash(x); hy = ScalarHelper::ComputeHash(y); return n ^ hx ^ hy; @@ -160,8 +159,8 @@ hash_t ComputeStringHash(const void* data, int64_t length) { // Apply the same principle as above uint64_t x, y; hash_t hx, hy; - x = *reinterpret_cast(p + n - 8); - y = *reinterpret_cast(p); + x = util::SafeLoadAs(p + n - 8); + y = util::SafeLoadAs(p); hx = ScalarHelper::ComputeHash(x); hy = ScalarHelper::ComputeHash(y); return n ^ hx ^ hy; diff --git a/cpp/src/arrow/util/ubsan.h b/cpp/src/arrow/util/ubsan.h index f9fcfb54022..758f5425a59 100644 --- a/cpp/src/arrow/util/ubsan.h +++ b/cpp/src/arrow/util/ubsan.h @@ -49,5 +49,21 @@ inline T* MakeNonNull(T* maybe_null) { return reinterpret_cast(&internal::non_null_filler); } +template +inline typename std::enable_if::value, T>::type SafeLoadAs( + const uint8_t* unaligned) { + typename std::remove_const::type ret; + std::memcpy(&ret, unaligned, sizeof(T)); + return ret; +} + +template +inline typename std::enable_if::value, T>::type SafeLoad( + const T* unaligned) { + typename std::remove_const::type ret; + std::memcpy(&ret, unaligned, sizeof(T)); + return ret; +} + } // namespace util } // namespace arrow diff --git a/cpp/src/parquet/arrow/reader.cc b/cpp/src/parquet/arrow/reader.cc index 3fe37b0e239..f757b5ff847 100644 --- a/cpp/src/parquet/arrow/reader.cc +++ b/cpp/src/parquet/arrow/reader.cc @@ -83,6 +83,7 @@ namespace arrow { using ::arrow::BitUtil::FromBigEndian; using ::arrow::internal::SafeLeftShift; +using ::arrow::util::SafeLoadAs; template using ArrayType = typename ::arrow::TypeTraits::ArrayType; @@ -1212,38 +1213,37 @@ static uint64_t BytesToInteger(const uint8_t* bytes, int32_t start, int32_t stop case 1: return bytes[start]; case 2: - return FromBigEndian(*reinterpret_cast(bytes + start)); + return FromBigEndian(SafeLoadAs(bytes + start)); case 3: { - const uint64_t 
first_two_bytes = - FromBigEndian(*reinterpret_cast(bytes + start)); + const uint64_t first_two_bytes = FromBigEndian(SafeLoadAs(bytes + start)); const uint64_t last_byte = bytes[stop - 1]; return first_two_bytes << 8 | last_byte; } case 4: - return FromBigEndian(*reinterpret_cast(bytes + start)); + return FromBigEndian(SafeLoadAs(bytes + start)); case 5: { const uint64_t first_four_bytes = - FromBigEndian(*reinterpret_cast(bytes + start)); + FromBigEndian(SafeLoadAs(bytes + start)); const uint64_t last_byte = bytes[stop - 1]; return first_four_bytes << 8 | last_byte; } case 6: { const uint64_t first_four_bytes = - FromBigEndian(*reinterpret_cast(bytes + start)); + FromBigEndian(SafeLoadAs(bytes + start)); const uint64_t last_two_bytes = - FromBigEndian(*reinterpret_cast(bytes + start + 4)); + FromBigEndian(SafeLoadAs(bytes + start + 4)); return first_four_bytes << 16 | last_two_bytes; } case 7: { const uint64_t first_four_bytes = - FromBigEndian(*reinterpret_cast(bytes + start)); + FromBigEndian(SafeLoadAs(bytes + start)); const uint64_t second_two_bytes = - FromBigEndian(*reinterpret_cast(bytes + start + 4)); + FromBigEndian(SafeLoadAs(bytes + start + 4)); const uint64_t last_byte = bytes[stop - 1]; return first_four_bytes << 24 | second_two_bytes << 8 | last_byte; } case 8: - return FromBigEndian(*reinterpret_cast(bytes + start)); + return FromBigEndian(SafeLoadAs(bytes + start)); default: { DCHECK(false); return UINT64_MAX; diff --git a/cpp/src/parquet/arrow/writer.h b/cpp/src/parquet/arrow/writer.h index 8014e1a3511..5a72da6b59a 100644 --- a/cpp/src/parquet/arrow/writer.h +++ b/cpp/src/parquet/arrow/writer.h @@ -211,8 +211,9 @@ inline void ArrowTimestampToImpalaTimestamp(const int64_t time, Int96* impala_ti (*impala_timestamp).value[2] = (uint32_t)julian_days; int64_t last_day_units = time % UnitPerDay; - int64_t* impala_last_day_nanos = reinterpret_cast(impala_timestamp); - *impala_last_day_nanos = last_day_units * NanosecondsPerUnit; + auto last_day_nanos = 
last_day_units * NanosecondsPerUnit; + // Storage might be unaligned, so use memcpy instead of reinterpret_cast + std::memcpy(impala_timestamp, &last_day_nanos, sizeof(int64_t)); } constexpr int64_t kSecondsInNanos = INT64_C(1000000000); diff --git a/cpp/src/parquet/column_reader.cc b/cpp/src/parquet/column_reader.cc index f66224edd47..130b75a5210 100644 --- a/cpp/src/parquet/column_reader.cc +++ b/cpp/src/parquet/column_reader.cc @@ -27,6 +27,7 @@ #include "arrow/util/compression.h" #include "arrow/util/logging.h" #include "arrow/util/rle-encoding.h" +#include "arrow/util/ubsan.h" #include "parquet/column_page.h" #include "parquet/encoding.h" @@ -50,7 +51,7 @@ int LevelDecoder::SetData(Encoding::type encoding, int16_t max_level, bit_width_ = BitUtil::Log2(max_level + 1); switch (encoding) { case Encoding::RLE: { - num_bytes = *reinterpret_cast(data); + num_bytes = arrow::util::SafeLoadAs(data); const uint8_t* decoder_data = data + sizeof(int32_t); if (!rle_decoder_) { rle_decoder_.reset( diff --git a/cpp/src/parquet/encoding.cc b/cpp/src/parquet/encoding.cc index 77f86e36f9b..304724b6b52 100644 --- a/cpp/src/parquet/encoding.cc +++ b/cpp/src/parquet/encoding.cc @@ -29,6 +29,7 @@ #include "arrow/util/logging.h" #include "arrow/util/rle-encoding.h" #include "arrow/util/string_view.h" +#include "arrow/util/ubsan.h" #include "parquet/exception.h" #include "parquet/platform.h" @@ -609,7 +610,7 @@ inline int DecodePlain(const uint8_t* data, int64_t data_size, int nu int bytes_decoded = 0; int increment; for (int i = 0; i < num_values; ++i) { - uint32_t len = out[i].len = *reinterpret_cast(data); + uint32_t len = out[i].len = arrow::util::SafeLoadAs(data); increment = static_cast(sizeof(uint32_t) + len); if (data_size < increment) ParquetException::EofException(); out[i].ptr = data + sizeof(uint32_t); @@ -719,7 +720,7 @@ class PlainByteArrayDecoder : public PlainDecoder, int bytes_decoded = 0; while (i < num_values) { if (bit_reader.IsSet()) { - uint32_t len =
*reinterpret_cast(data); + uint32_t len = arrow::util::SafeLoadAs(data); increment = static_cast(sizeof(uint32_t) + len); if (data_size < increment) { ParquetException::EofException(); @@ -752,7 +753,7 @@ class PlainByteArrayDecoder : public PlainDecoder, int bytes_decoded = 0; while (i < num_values) { - uint32_t len = *reinterpret_cast(data); + uint32_t len = arrow::util::SafeLoadAs(data); int increment = static_cast(sizeof(uint32_t) + len); if (data_size < increment) ParquetException::EofException(); builder->Append(data + sizeof(uint32_t), len); @@ -1103,7 +1104,7 @@ class DeltaLengthByteArrayDecoder : public DecoderImpl, virtual void SetData(int num_values, const uint8_t* data, int len) { num_values_ = num_values; if (len == 0) return; - int total_lengths_len = *reinterpret_cast(data); + int total_lengths_len = arrow::util::SafeLoadAs(data); data += 4; this->len_decoder_.SetData(num_values, data, total_lengths_len); data_ = data + total_lengths_len; @@ -1145,7 +1146,7 @@ class DeltaByteArrayDecoder : public DecoderImpl, virtual void SetData(int num_values, const uint8_t* data, int len) { num_values_ = num_values; if (len == 0) return; - int prefix_len_length = *reinterpret_cast(data); + int prefix_len_length = arrow::util::SafeLoadAs(data); data += 4; len -= 4; prefix_len_decoder_.SetData(num_values, data, prefix_len_length); diff --git a/cpp/src/parquet/file_reader.cc b/cpp/src/parquet/file_reader.cc index 959ea0dfb06..d0ca9ca809d 100644 --- a/cpp/src/parquet/file_reader.cc +++ b/cpp/src/parquet/file_reader.cc @@ -28,6 +28,7 @@ #include "arrow/io/file.h" #include "arrow/status.h" #include "arrow/util/logging.h" +#include "arrow/util/ubsan.h" #include "parquet/column_reader.h" #include "parquet/column_scanner.h" @@ -179,7 +180,7 @@ class SerializedFile : public ParquetFileReader::Contents { throw ParquetException("Invalid parquet file. 
Corrupt footer."); } - uint32_t metadata_len = *reinterpret_cast( + uint32_t metadata_len = arrow::util::SafeLoadAs( reinterpret_cast(footer_buffer->data()) + footer_read_size - kFooterSize); int64_t metadata_start = file_size - kFooterSize - metadata_len; diff --git a/cpp/src/plasma/common.cc b/cpp/src/plasma/common.cc index 490aa158b33..0f1a0d1b505 100644 --- a/cpp/src/plasma/common.cc +++ b/cpp/src/plasma/common.cc @@ -19,6 +19,8 @@ #include +#include "arrow/util/ubsan.h" + #include "plasma/plasma_generated.h" namespace fb = plasma::flatbuf; @@ -64,7 +66,7 @@ uint64_t MurmurHash64A(const void* key, int len, unsigned int seed) { const uint64_t* end = data + (len / 8); while (data != end) { - uint64_t k = *data++; + uint64_t k = arrow::util::SafeLoad(data++); k *= m; k ^= k >> r; From 7b0335ffb8cae84e401ce62e5d6d199407a8c621 Mon Sep 17 00:00:00 2001 From: Pindikura Ravindra Date: Tue, 2 Jul 2019 18:46:28 +0530 Subject: [PATCH 30/52] ARROW-5818: [Java][Gandiva] support varlen output vectors callback to java for resizing varlen vectors Author: Pindikura Ravindra Closes #4771 from pravindra/jvarlen and squashes the following commits: d9954e865 add check for null expander b710a7792 ARROW-5818: support varlen output vectors --- cpp/src/gandiva/jni/jni_common.cc | 106 +++++++++++++++--- .../arrow/gandiva/evaluator/JniWrapper.java | 3 +- .../arrow/gandiva/evaluator/Projector.java | 22 +++- .../gandiva/evaluator/VectorExpander.java | 69 ++++++++++++ .../gandiva/evaluator/ProjectorTest.java | 2 - 5 files changed, 177 insertions(+), 25 deletions(-) create mode 100644 java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java diff --git a/cpp/src/gandiva/jni/jni_common.cc b/cpp/src/gandiva/jni/jni_common.cc index 2ff4bc9619a..eeaaca798de 100644 --- a/cpp/src/gandiva/jni/jni_common.cc +++ b/cpp/src/gandiva/jni/jni_common.cc @@ -72,6 +72,11 @@ jclass configuration_builder_class_; // refs for self. 
static jclass gandiva_exception_; +static jclass vector_expander_class_; +static jclass vector_expander_ret_class_; +static jmethodID vector_expander_method_; +static jfieldID vector_expander_ret_address_; +static jfieldID vector_expander_ret_capacity_; // module maps gandiva::IdToModuleMap> projector_modules_; @@ -91,10 +96,27 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { jclass localExceptionClass = env->FindClass("org/apache/arrow/gandiva/exceptions/GandivaException"); gandiva_exception_ = (jclass)env->NewGlobalRef(localExceptionClass); + env->ExceptionDescribe(); env->DeleteLocalRef(localExceptionClass); - env->ExceptionDescribe(); + jclass local_expander_class = + env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander"); + vector_expander_class_ = (jclass)env->NewGlobalRef(local_expander_class); + env->DeleteLocalRef(local_expander_class); + + vector_expander_method_ = env->GetMethodID( + vector_expander_class_, "expandOutputVectorAtIndex", + "(II)Lorg/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult;"); + jclass local_expander_ret_class = + env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult"); + vector_expander_ret_class_ = (jclass)env->NewGlobalRef(local_expander_ret_class); + env->DeleteLocalRef(local_expander_ret_class); + + vector_expander_ret_address_ = + env->GetFieldID(vector_expander_ret_class_, "address", "J"); + vector_expander_ret_capacity_ = + env->GetFieldID(vector_expander_ret_class_, "capacity", "I"); return JNI_VERSION; } @@ -103,6 +125,8 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) { vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); env->DeleteGlobalRef(configuration_builder_class_); env->DeleteGlobalRef(gandiva_exception_); + env->DeleteGlobalRef(vector_expander_class_); + env->DeleteGlobalRef(vector_expander_ret_class_); } DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType& ext_type) { @@ -637,27 +661,62 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_build /// class JavaResizableBuffer : public arrow::ResizableBuffer { public: - JavaResizableBuffer(uint8_t* buffer, int32_t len) : ResizableBuffer(buffer, len) { + JavaResizableBuffer(JNIEnv* env, jobject jexpander, int32_t vector_idx, uint8_t* buffer, + int32_t len) + : ResizableBuffer(buffer, len), + env_(env), + jexpander_(jexpander), + vector_idx_(vector_idx) { size_ = 0; } - Status Resize(const int64_t new_size, bool shrink_to_fit) override { - if (shrink_to_fit == true) { - return Status::NotImplemented("shrink not implemented"); - } else if (new_size < capacity()) { - size_ = new_size; - return Status::OK(); - } else { - // TODO: callback into java to re-alloc the buffer. - return Status::NotImplemented("buffer expand not implemented"); - } - } + Status Resize(const int64_t new_size, bool shrink_to_fit) override; Status Reserve(const int64_t new_capacity) override { return Status::NotImplemented("reserve not implemented"); } + + private: + JNIEnv* env_; + jobject jexpander_; + int32_t vector_idx_; }; +Status JavaResizableBuffer::Resize(const int64_t new_size, bool shrink_to_fit) { + if (shrink_to_fit == true) { + return Status::NotImplemented("shrink not implemented"); + } + + if (ARROW_PREDICT_TRUE(new_size < capacity())) { + // no need to expand. 
+ size_ = new_size; + return Status::OK(); + } + + if (new_size > INT32_MAX) { + return Status::OutOfMemory("java supports buffer sizes upto 2GB only"); + } + + // callback into java to expand the buffer + int32_t updated_capacity = static_cast(new_size); + jobject ret = env_->CallObjectMethod(jexpander_, vector_expander_method_, vector_idx_, + updated_capacity); + if (env_->ExceptionCheck()) { + env_->ExceptionDescribe(); + env_->ExceptionClear(); + return Status::OutOfMemory("buffer expand failed in java"); + } + + jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_); + jint ret_capacity = env_->GetIntField(ret, vector_expander_ret_capacity_); + DCHECK_GE(ret_capacity, updated_capacity); + + data_ = mutable_data_ = reinterpret_cast(ret_address); + size_ = new_size; + capacity_ = ret_capacity; + return Status::OK(); +} + #define CHECK_OUT_BUFFER_IDX_AND_BREAK(idx, len) \ if (idx >= len) { \ status = gandiva::Status::Invalid("insufficient number of out_buf_addrs"); \ @@ -666,9 +725,10 @@ class JavaResizableBuffer : public arrow::ResizableBuffer { JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( - JNIEnv* env, jobject cls, jlong module_id, jint num_rows, jlongArray buf_addrs, - jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows, jlong sel_vec_addr, - jlong sel_vec_size, jlongArray out_buf_addrs, jlongArray out_buf_sizes) { + JNIEnv* env, jobject object, jobject jexpander, jlong module_id, jint num_rows, + jlongArray buf_addrs, jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows, + jlong sel_vec_addr, jlong sel_vec_size, jlongArray out_buf_addrs, + jlongArray out_buf_sizes) { Status status; std::shared_ptr holder = projector_modules_.Lookup(module_id); if (holder == nullptr) { @@ -735,6 +795,7 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( ArrayDataVector output; int buf_idx = 0; int sz_idx = 0; + int output_vector_idx = 0; for (FieldPtr field : ret_types) { 
std::vector> buffers; @@ -755,13 +816,24 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( uint8_t* value_buf = reinterpret_cast(out_bufs[buf_idx++]); jlong data_sz = out_sizes[sz_idx++]; if (arrow::is_binary_like(field->type()->id())) { - buffers.push_back(std::make_shared(value_buf, data_sz)); + if (jexpander == nullptr) { + status = Status::Invalid( + "expression has variable len output columns, but the expander object is " + "null"); + break; + } + buffers.push_back(std::make_shared( + env, jexpander, output_vector_idx, value_buf, data_sz)); } else { buffers.push_back(std::make_shared(value_buf, data_sz)); } auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers); output.push_back(array_data); + ++output_vector_idx; + } + if (!status.ok()) { + break; } status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output); } while (0); diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java index ef1d63ae29b..520ef5f443e 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java @@ -48,6 +48,7 @@ native long buildProjector(byte[] schemaBuf, byte[] exprListBuf, * Evaluate the expressions represented by the moduleId on a record batch * and store the output in ValueVectors. Throws an exception in case of errors * + * @param expander VectorExpander object. Used for callbacks from cpp. * @param moduleId moduleId representing expressions. Created using a call to * buildNativeCode * @param numRows Number of rows in the record batch @@ -61,7 +62,7 @@ native long buildProjector(byte[] schemaBuf, byte[] exprListBuf, * @param outSizes The allocated size of the output buffers. 
On successful evaluation, * the result is stored in the output buffers */ - native void evaluateProjector(long moduleId, int numRows, + native void evaluateProjector(Object expander, long moduleId, int numRows, long[] bufAddrs, long[] bufSizes, int selectionVectorType, int selectionVectorSize, long selectionVectorBufferAddr, long selectionVectorBufferSize, diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java index 93657e6f78f..c15d474a282 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java @@ -27,6 +27,7 @@ import org.apache.arrow.gandiva.expression.ExpressionTree; import org.apache.arrow.gandiva.ipc.GandivaTypes; import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType; +import org.apache.arrow.vector.BaseVariableWidthVector; import org.apache.arrow.vector.FixedWidthVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VariableWidthVector; @@ -236,14 +237,18 @@ private void evaluate(int numRows, List buffers, List buf bufSizes[idx++] = bufLayout.getSize(); } + boolean hasVariableWidthColumns = false; + BaseVariableWidthVector[] resizableVectors = new BaseVariableWidthVector[outColumns.size()]; long[] outAddrs = new long[3 * outColumns.size()]; long[] outSizes = new long[3 * outColumns.size()]; idx = 0; + int outColumnIdx = 0; for (ValueVector valueVector : outColumns) { boolean isFixedWith = valueVector instanceof FixedWidthVector; boolean isVarWidth = valueVector instanceof VariableWidthVector; if (!isFixedWith && !isVarWidth) { - throw new UnsupportedTypeException("Unsupported value vector type " + valueVector.getField().getFieldType()); + throw new UnsupportedTypeException( + "Unsupported value vector type " + valueVector.getField().getFieldType()); } outAddrs[idx] = 
valueVector.getValidityBuffer().memoryAddress(); @@ -251,17 +256,24 @@ private void evaluate(int numRows, List buffers, List buf if (isVarWidth) { outAddrs[idx] = valueVector.getOffsetBuffer().memoryAddress(); outSizes[idx++] = valueVector.getOffsetBuffer().capacity(); + hasVariableWidthColumns = true; + + // save vector to allow for resizing. + resizableVectors[outColumnIdx] = (BaseVariableWidthVector)valueVector; } outAddrs[idx] = valueVector.getDataBuffer().memoryAddress(); outSizes[idx++] = valueVector.getDataBuffer().capacity(); valueVector.setValueCount(selectionVectorRecordCount); + outColumnIdx++; } - wrapper.evaluateProjector(this.moduleId, numRows, bufAddrs, bufSizes, - selectionVectorType, selectionVectorRecordCount, - selectionVectorAddr, selectionVectorSize, - outAddrs, outSizes); + wrapper.evaluateProjector( + hasVariableWidthColumns ? new VectorExpander(resizableVectors) : null, + this.moduleId, numRows, bufAddrs, bufSizes, + selectionVectorType, selectionVectorRecordCount, + selectionVectorAddr, selectionVectorSize, + outAddrs, outSizes); } /** diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java new file mode 100644 index 00000000000..2414144a853 --- /dev/null +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import org.apache.arrow.vector.BaseVariableWidthVector; + +/** + * This class provides the functionality to expand output vectors using a callback mechanism from + * gandiva. + */ +public class VectorExpander { + private final BaseVariableWidthVector[] vectors; + + public VectorExpander(BaseVariableWidthVector[] vectors) { + this.vectors = vectors; + } + + /** + * Result of vector expansion. + */ + public static class ExpandResult { + public long address; + public int capacity; + + public ExpandResult(long address, int capacity) { + this.address = address; + this.capacity = capacity; + } + } + + /** + * Expand vector at specified index. This is used as a callback from jni, and is only + * relevant for variable width vectors. + * + * @param index index of buffer in the list passed to jni. + * @param toCapacity the size to which the buffer should be expanded. + * + * @return address and size of the buffer after expansion.
+ */ + public ExpandResult expandOutputVectorAtIndex(int index, int toCapacity) { + if (index >= vectors.length || vectors[index] == null) { + throw new IllegalArgumentException("invalid index " + index); + } + + BaseVariableWidthVector vector = vectors[index]; + while (vector.getDataBuffer().capacity() < toCapacity) { + vector.reallocDataBuffer(); + } + return new ExpandResult( + vector.getDataBuffer().memoryAddress(), + vector.getDataBuffer().capacity()); + } + +} diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java index 2fd80910db8..52eeb165a4d 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java @@ -580,8 +580,6 @@ public void testStringOutput() throws GandivaException { // test with insufficient data buffer. try { outVector.allocateNew(4, numRows); - thrown.expect(GandivaException.class); - thrown.expectMessage("expand not implemented"); eval.evaluate(batch, output); } finally { releaseRecordBatch(batch); From e70c3a7ade4b865e41ec862d7a04ef5653f99ec1 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Tue, 2 Jul 2019 18:03:15 +0200 Subject: [PATCH 31/52] ARROW-5823: [Rust] CI scripts miss --all-targets cargo argument Rust build breaks because some changes in array builder. However this error is not detected in ci scripts because missing --all-targets in cargo build command. 
Author: Renjie Liu Closes #4778 from liurenjie1024/fix-build and squashes the following commits: c00f8bce1 Fix build break --- ci/rust-build-main.bat | 2 +- ci/travis_script_rust.sh | 2 +- rust/arrow/benches/arithmetic_kernels.rs | 1 - rust/arrow/benches/array_from_vec.rs | 1 - rust/arrow/benches/boolean_kernels.rs | 1 - rust/arrow/benches/builder.rs | 2 +- rust/arrow/benches/comparison_kernels.rs | 1 - 7 files changed, 3 insertions(+), 7 deletions(-) diff --git a/ci/rust-build-main.bat b/ci/rust-build-main.bat index 5bf1c843928..e7f3c32a549 100644 --- a/ci/rust-build-main.bat +++ b/ci/rust-build-main.bat @@ -29,7 +29,7 @@ pushd rust rustup default nightly rustup show -cargo build --target %TARGET% --release || exit /B +cargo build --target %TARGET% --all-targets --release || exit /B @echo @echo Test (release) @echo -------------- diff --git a/ci/travis_script_rust.sh b/ci/travis_script_rust.sh index c25d64ec42c..704cb37bb06 100755 --- a/ci/travis_script_rust.sh +++ b/ci/travis_script_rust.sh @@ -31,7 +31,7 @@ rustup show # raises on any formatting errors cargo +stable fmt --all -- --check -RUSTFLAGS="-D warnings" cargo build +RUSTFLAGS="-D warnings" cargo build --all-targets cargo test # run examples diff --git a/rust/arrow/benches/arithmetic_kernels.rs b/rust/arrow/benches/arithmetic_kernels.rs index 855355d9f5c..e9851684702 100644 --- a/rust/arrow/benches/arithmetic_kernels.rs +++ b/rust/arrow/benches/arithmetic_kernels.rs @@ -24,7 +24,6 @@ use std::sync::Arc; extern crate arrow; use arrow::array::*; -use arrow::builder::*; use arrow::compute::array_ops::{limit, sum}; use arrow::compute::kernels::arithmetic::*; use arrow::error::Result; diff --git a/rust/arrow/benches/array_from_vec.rs b/rust/arrow/benches/array_from_vec.rs index f9357140922..1918e61c913 100644 --- a/rust/arrow/benches/array_from_vec.rs +++ b/rust/arrow/benches/array_from_vec.rs @@ -22,7 +22,6 @@ use criterion::Criterion; extern crate arrow; use arrow::array::*; -use 
arrow::array_data::ArrayDataBuilder; use arrow::buffer::Buffer; use arrow::datatypes::*; diff --git a/rust/arrow/benches/boolean_kernels.rs b/rust/arrow/benches/boolean_kernels.rs index d01c9df920a..3a544ace4f5 100644 --- a/rust/arrow/benches/boolean_kernels.rs +++ b/rust/arrow/benches/boolean_kernels.rs @@ -22,7 +22,6 @@ use criterion::Criterion; extern crate arrow; use arrow::array::*; -use arrow::builder::*; use arrow::compute::kernels::boolean as boolean_kernels; use arrow::error::{ArrowError, Result}; diff --git a/rust/arrow/benches/builder.rs b/rust/arrow/benches/builder.rs index 70369804f87..c13874be60a 100644 --- a/rust/arrow/benches/builder.rs +++ b/rust/arrow/benches/builder.rs @@ -25,7 +25,7 @@ use criterion::*; use rand::distributions::Standard; use rand::{thread_rng, Rng}; -use arrow::builder::*; +use arrow::array::*; // Build arrays with 512k elements. const BATCH_SIZE: usize = 8 << 10; diff --git a/rust/arrow/benches/comparison_kernels.rs b/rust/arrow/benches/comparison_kernels.rs index bd75b6ac3f7..77f6d8361d4 100644 --- a/rust/arrow/benches/comparison_kernels.rs +++ b/rust/arrow/benches/comparison_kernels.rs @@ -22,7 +22,6 @@ use criterion::Criterion; extern crate arrow; use arrow::array::*; -use arrow::builder::*; use arrow::compute::*; fn create_array(size: usize) -> Float32Array { From c645a3791448a2498c4b5e6acd6ff70ea493c8fd Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Tue, 2 Jul 2019 18:56:40 -0500 Subject: [PATCH 32/52] [Release] Set C++ libraries runtime path to LD_LIBRARY_PATH when running integration tests (#4775) This is also required (and set) when running the unit tests * Set LD_LIBRARY_PATH in integration tests * Code review [skip ci] --- dev/release/verify-release-candidate.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 8b25d304c23..0acb56e4d8a 100755 --- a/dev/release/verify-release-candidate.sh +++ 
b/dev/release/verify-release-candidate.sh @@ -486,7 +486,10 @@ test_integration() { INTEGRATION_TEST_ARGS=--run_flight fi - python integration_test.py $INTEGRATION_TEST_ARGS + # Flight integration test executables have a runtime dependency on + # release/libgtest.so + LD_LIBRARY_PATH=$ARROW_CPP_EXE_PATH:$LD_LIBRARY_PATH \ + python integration_test.py $INTEGRATION_TEST_ARGS popd } From 732a17b791de27b3b521ff7561c206054bb139fa Mon Sep 17 00:00:00 2001 From: Micah Kornfield Date: Tue, 2 Jul 2019 19:16:10 -0500 Subject: [PATCH 33/52] ARROW-5774: [Java][Documentation] Update how to run test git submodule initialization is now necessary to get flight Certs. Author: Micah Kornfield Closes #4737 from emkornfield/update_flight_docs and squashes the following commits: 148429ced ARROW-5774: Update with git command needed to get flight test cert --- java/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/java/README.md b/java/README.md index b19bfdafecf..23575bfd1a3 100644 --- a/java/README.md +++ b/java/README.md @@ -28,6 +28,7 @@ install: ## Building and running tests ``` +git submodule update --init --recursive # Needed for flight cd java mvn install ``` From 913d82d1c21d35d438fc01081dcace901f4def02 Mon Sep 17 00:00:00 2001 From: Pindikura Ravindra Date: Wed, 3 Jul 2019 07:08:06 +0530 Subject: [PATCH 34/52] ARROW-5483: [Java] add ValueVector constructors that take Field object Author: Pindikura Ravindra Closes #4614 from pravindra/vecopt and squashes the following commits: 581399266 ARROW-5483: add ValueVector constructors that take Field object --- .../main/codegen/templates/UnionVector.java | 9 +- .../arrow/vector/BaseFixedWidthVector.java | 18 +- .../apache/arrow/vector/BaseValueVector.java | 10 +- .../arrow/vector/BaseVariableWidthVector.java | 18 +- .../org/apache/arrow/vector/BigIntVector.java | 14 +- .../org/apache/arrow/vector/BitVector.java | 14 +- .../apache/arrow/vector/DateDayVector.java | 14 +- .../apache/arrow/vector/DateMilliVector.java | 14 +-
.../apache/arrow/vector/DecimalVector.java | 16 +- .../apache/arrow/vector/DurationVector.java | 17 +- .../arrow/vector/ExtensionTypeVector.java | 25 +- .../arrow/vector/FixedSizeBinaryVector.java | 16 +- .../org/apache/arrow/vector/Float4Vector.java | 14 +- .../org/apache/arrow/vector/Float8Vector.java | 14 +- .../org/apache/arrow/vector/IntVector.java | 14 +- .../arrow/vector/IntervalDayVector.java | 14 +- .../arrow/vector/IntervalYearVector.java | 14 +- .../apache/arrow/vector/SmallIntVector.java | 14 +- .../apache/arrow/vector/TimeMicroVector.java | 14 +- .../apache/arrow/vector/TimeMilliVector.java | 14 +- .../apache/arrow/vector/TimeNanoVector.java | 14 +- .../apache/arrow/vector/TimeSecVector.java | 14 +- .../arrow/vector/TimeStampMicroTZVector.java | 15 ++ .../arrow/vector/TimeStampMicroVector.java | 13 ++ .../arrow/vector/TimeStampMilliTZVector.java | 15 ++ .../arrow/vector/TimeStampMilliVector.java | 13 ++ .../arrow/vector/TimeStampNanoTZVector.java | 15 ++ .../arrow/vector/TimeStampNanoVector.java | 13 ++ .../arrow/vector/TimeStampSecTZVector.java | 15 ++ .../arrow/vector/TimeStampSecVector.java | 13 ++ .../apache/arrow/vector/TimeStampVector.java | 14 +- .../apache/arrow/vector/TinyIntVector.java | 14 +- .../org/apache/arrow/vector/UInt1Vector.java | 7 +- .../org/apache/arrow/vector/UInt2Vector.java | 7 +- .../org/apache/arrow/vector/UInt4Vector.java | 7 +- .../org/apache/arrow/vector/UInt8Vector.java | 7 +- .../apache/arrow/vector/VarBinaryVector.java | 14 +- .../apache/arrow/vector/VarCharVector.java | 14 +- .../complex/BaseRepeatedValueVector.java | 9 +- .../vector/complex/FixedSizeListVector.java | 9 +- .../arrow/vector/complex/ListVector.java | 2 +- .../org/apache/arrow/vector/types/Types.java | 213 ++++++++---------- .../apache/arrow/vector/types/pojo/Field.java | 2 +- .../arrow/vector/types/pojo/FieldType.java | 5 + .../apache/arrow/vector/TestVectorAlloc.java | 106 +++++++++ 45 files changed, 689 insertions(+), 178 deletions(-) create mode 
100644 java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index 26f4673ee4f..04eed725379 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -17,6 +17,7 @@ import io.netty.buffer.ArrowBuf; import org.apache.arrow.memory.ReferenceManager; +import org.apache.arrow.vector.types.pojo.FieldType; <@pp.dropOutputFile /> <@pp.changeOutputFile name="/org/apache/arrow/vector/complex/UnionVector.java" /> @@ -74,14 +75,18 @@ public class UnionVector implements FieldVector { private int singleType = 0; private ValueVector singleVector; - private static final byte TYPE_WIDTH = 1; private final CallBack callBack; private int typeBufferAllocationSizeInBytes; + private static final byte TYPE_WIDTH = 1; + private static final FieldType INTERNAL_STRUCT_TYPE = new FieldType(false /*nullable*/, + ArrowType.Struct.INSTANCE, null /*dictionary*/, null /*metadata*/); + public UnionVector(String name, BufferAllocator allocator, CallBack callBack) { this.name = name; this.allocator = allocator; - this.internalStruct = new NonNullableStructVector("internal", allocator, new FieldType(false, ArrowType.Struct.INSTANCE, null, null), callBack); + this.internalStruct = new NonNullableStructVector("internal", allocator, INTERNAL_STRUCT_TYPE, + callBack); this.typeBuffer = allocator.getEmpty(); this.callBack = callBack; this.typeBufferAllocationSizeInBytes = BaseValueVector.INITIAL_VALUE_ALLOCATION * TYPE_WIDTH; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java index 91937caa4db..e4bed23ba4f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java 
@@ -25,7 +25,6 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.OversizedAllocationException; import org.apache.arrow.vector.util.TransferPair; @@ -52,16 +51,14 @@ public abstract class BaseFixedWidthVector extends BaseValueVector /** * Constructs a new instance. * - * @param name The name of the vector. + * @param field field materialized by this vector * @param allocator The allocator to use for allocating memory for the vector. - * @param fieldType The type of the buffer. * @param typeWidth The width in bytes of the type. */ - public BaseFixedWidthVector(final String name, final BufferAllocator allocator, - FieldType fieldType, final int typeWidth) { - super(name, allocator); + public BaseFixedWidthVector(Field field, final BufferAllocator allocator, final int typeWidth) { + super(allocator); this.typeWidth = typeWidth; - field = new Field(name, fieldType, null); + this.field = field; valueCount = 0; allocationMonitor = 0; validityBuffer = allocator.getEmpty(); @@ -70,6 +67,11 @@ public BaseFixedWidthVector(final String name, final BufferAllocator allocator, } + @Override + public String getName() { + return field.getName(); + } + /* TODO: * see if getNullCount() can be made faster -- O(1) */ @@ -533,7 +535,7 @@ public TransferPair getTransferPair(String ref, BufferAllocator allocator, CallB */ @Override public TransferPair getTransferPair(BufferAllocator allocator) { - return getTransferPair(name, allocator); + return getTransferPair(getName(), allocator); } /** diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java index bc12e8e7180..fc8e2e70d22 100644 --- 
a/java/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java @@ -50,16 +50,16 @@ public abstract class BaseValueVector implements ValueVector { public static final int INITIAL_VALUE_ALLOCATION = 3970; protected final BufferAllocator allocator; - protected final String name; - protected BaseValueVector(String name, BufferAllocator allocator) { + protected BaseValueVector(BufferAllocator allocator) { this.allocator = Preconditions.checkNotNull(allocator, "allocator cannot be null"); - this.name = name; } + public abstract String getName(); + @Override public String toString() { - return super.toString() + "[name = " + name + ", ...]"; + return super.toString() + "[name = " + getName() + ", ...]"; } @Override @@ -73,7 +73,7 @@ public void close() { @Override public TransferPair getTransferPair(BufferAllocator allocator) { - return getTransferPair(name, allocator); + return getTransferPair(getName(), allocator); } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java index 54913029531..e7fa28964d7 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java @@ -27,7 +27,6 @@ import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.OversizedAllocationException; import org.apache.arrow.vector.util.TransferPair; @@ -58,17 +57,15 @@ public abstract class BaseVariableWidthVector extends BaseValueVector /** * Constructs a new instance. 
* - * @param name A name for the vector + * @param field The field materialized by this vector. * @param allocator The allocator to use for creating/resizing buffers - * @param fieldType The type of this vector. */ - public BaseVariableWidthVector(final String name, final BufferAllocator allocator, - FieldType fieldType) { - super(name, allocator); + public BaseVariableWidthVector(Field field, final BufferAllocator allocator) { + super(allocator); + this.field = field; lastValueAllocationSizeInBytes = INITIAL_BYTE_COUNT; // -1 because we require one extra slot for the offset array. lastValueCapacity = INITIAL_VALUE_ALLOCATION - 1; - field = new Field(name, fieldType, null); valueCount = 0; lastSet = -1; offsetBuffer = allocator.getEmpty(); @@ -76,6 +73,11 @@ public BaseVariableWidthVector(final String name, final BufferAllocator allocato valueBuffer = allocator.getEmpty(); } + @Override + public String getName() { + return field.getName(); + } + /* TODO: * see if getNullCount() can be made faster -- O(1) */ @@ -656,7 +658,7 @@ public TransferPair getTransferPair(String ref, BufferAllocator allocator, CallB */ @Override public TransferPair getTransferPair(BufferAllocator allocator) { - return getTransferPair(name, allocator); + return getTransferPair(getName(), allocator); } /** diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java index 416ffd53fd3..6d235dd1e6c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java @@ -25,6 +25,7 @@ import org.apache.arrow.vector.holders.BigIntHolder; import org.apache.arrow.vector.holders.NullableBigIntHolder; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.TransferPair; @@ -59,7 +60,18 @@ 
public BigIntVector(String name, BufferAllocator allocator) { * @param allocator allocator for memory management. */ public BigIntVector(String name, FieldType fieldType, BufferAllocator allocator) { - super(name, allocator, fieldType, TYPE_WIDTH); + this(new Field(name, fieldType, null), allocator); + } + + /** + * Instantiate a BigIntVector. This doesn't allocate any memory for + * the data in vector. + * + * @param field field materialized by this vector + * @param allocator allocator for memory management. + */ + public BigIntVector(Field field, BufferAllocator allocator) { + super(field, allocator, TYPE_WIDTH); reader = new BigIntReaderImpl(BigIntVector.this); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java index ebaca4e72f2..f75ccdc69d0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java @@ -25,6 +25,7 @@ import org.apache.arrow.vector.holders.BitHolder; import org.apache.arrow.vector.holders.NullableBitHolder; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.OversizedAllocationException; import org.apache.arrow.vector.util.TransferPair; @@ -59,7 +60,18 @@ public BitVector(String name, BufferAllocator allocator) { * @param allocator allocator for memory management. */ public BitVector(String name, FieldType fieldType, BufferAllocator allocator) { - super(name, allocator, fieldType, 0); + this(new Field(name, fieldType, null), allocator); + } + + /** + * Instantiate a BitVector. This doesn't allocate any memory for + * the data in vector. + * + * @param field the Field materialized by this vector + * @param allocator allocator for memory management. 
+ */ + public BitVector(Field field, BufferAllocator allocator) { + super(field, allocator,0); reader = new BitReaderImpl(BitVector.this); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java index 1e2b012748c..e634e7ef334 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java @@ -25,6 +25,7 @@ import org.apache.arrow.vector.holders.DateDayHolder; import org.apache.arrow.vector.holders.NullableDateDayHolder; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.TransferPair; @@ -59,7 +60,18 @@ public DateDayVector(String name, BufferAllocator allocator) { * @param allocator allocator for memory management. */ public DateDayVector(String name, FieldType fieldType, BufferAllocator allocator) { - super(name, allocator, fieldType, TYPE_WIDTH); + this(new Field(name, fieldType, null), allocator); + } + + /** + * Instantiate a DateDayVector. This doesn't allocate any memory for + * the data in vector. + * + * @param field Field materialized by this vector + * @param allocator allocator for memory management. 
+ */ + public DateDayVector(Field field, BufferAllocator allocator) { + super(field, allocator, TYPE_WIDTH); reader = new DateDayReaderImpl(DateDayVector.this); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java index e8ea5be11c9..7ea427dd2a1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java @@ -27,6 +27,7 @@ import org.apache.arrow.vector.holders.DateMilliHolder; import org.apache.arrow.vector.holders.NullableDateMilliHolder; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.DateUtility; import org.apache.arrow.vector.util.TransferPair; @@ -62,7 +63,18 @@ public DateMilliVector(String name, BufferAllocator allocator) { * @param allocator allocator for memory management. */ public DateMilliVector(String name, FieldType fieldType, BufferAllocator allocator) { - super(name, allocator, fieldType, TYPE_WIDTH); + this(new Field(name, fieldType, null), allocator); + } + + /** + * Instantiate a DateMilliVector. This doesn't allocate any memory for + * the data in vector. + * + * @param field field materialized by this vector + * @param allocator allocator for memory management. 
+ */ + public DateMilliVector(Field field, BufferAllocator allocator) { + super(field, allocator, TYPE_WIDTH); reader = new DateMilliReaderImpl(DateMilliVector.this); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java index 9664bee58f3..4fc35a33126 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java @@ -28,6 +28,7 @@ import org.apache.arrow.vector.holders.NullableDecimalHolder; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.TransferPair; @@ -68,8 +69,19 @@ public DecimalVector(String name, BufferAllocator allocator, * @param allocator allocator for memory management. */ public DecimalVector(String name, FieldType fieldType, BufferAllocator allocator) { - super(name, allocator, fieldType, TYPE_WIDTH); - ArrowType.Decimal arrowType = (ArrowType.Decimal) fieldType.getType(); + this(new Field(name, fieldType, null), allocator); + } + + /** + * Instantiate a DecimalVector. This doesn't allocate any memory for + * the data in vector. + * + * @param field field materialized by this vector + * @param allocator allocator for memory management. 
+ */ + public DecimalVector(Field field, BufferAllocator allocator) { + super(field, allocator, TYPE_WIDTH); + ArrowType.Decimal arrowType = (ArrowType.Decimal) field.getFieldType().getType(); reader = new DecimalReaderImpl(DecimalVector.this); this.precision = arrowType.getPrecision(); this.scale = arrowType.getScale(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java index 312c8e51309..76572bc9553 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java @@ -29,6 +29,7 @@ import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.TransferPair; @@ -54,10 +55,20 @@ public class DurationVector extends BaseFixedWidthVector { * @param allocator allocator for memory management. */ public DurationVector(String name, FieldType fieldType, BufferAllocator allocator) { - super(name, allocator, fieldType, TYPE_WIDTH); - reader = new DurationReaderImpl(DurationVector.this); - this.unit = ((ArrowType.Duration)fieldType.getType()).getUnit(); + this(new Field(name, fieldType, null), allocator); + } + /** + * Instantiate a DurationVector. This doesn't allocate any memory for + * the data in vector. + * + * @param field field materialized by this vector + * @param allocator allocator for memory management. 
+ */ + public DurationVector(Field field, BufferAllocator allocator) { + super(field, allocator, TYPE_WIDTH); + reader = new DurationReaderImpl(DurationVector.this); + this.unit = ((ArrowType.Duration)field.getFieldType().getType()).getUnit(); } /** diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java b/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java index 9594d9e5814..14a66f8dafa 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java @@ -39,12 +39,35 @@ public abstract class ExtensionTypeVector children) { * Construct a new vector of this type using the given allocator. */ public FieldVector createVector(BufferAllocator allocator) { - FieldVector vector = fieldType.createNewSingleVector(name, allocator, null); + FieldVector vector = fieldType.createNewSingleVector(this, allocator, null); vector.initializeChildrenFromFields(children); return vector; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java index 4cc4067c997..945f5df2d98 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java @@ -98,4 +98,9 @@ public FieldVector createNewSingleVector(String name, BufferAllocator allocator, return minorType.getNewVector(name, this, allocator, schemaCallBack); } + public FieldVector createNewSingleVector(Field field, BufferAllocator allocator, CallBack schemaCallBack) { + MinorType minorType = Types.getMinorTypeForArrowType(type); + return minorType.getNewVector(field, allocator, schemaCallBack); + } + } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java new file mode 
100644 index 00000000000..089f1f84ff8 --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import java.util.Collections; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.Decimal; +import org.apache.arrow.vector.types.pojo.ArrowType.Duration; +import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeBinary; +import org.apache.arrow.vector.types.pojo.ArrowType.Timestamp; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestVectorAlloc { + private BufferAllocator rootAllocator; + + @Before + public void init() 
{ + rootAllocator = new RootAllocator(Long.MAX_VALUE); + } + + @After + public void terminate() throws Exception { + rootAllocator.close(); + } + + private static Field field(String name, ArrowType type) { + return new Field(name, new FieldType(true, type, null), Collections.emptyList()); + } + + @Test + public void testVectorAllocWithField() { + Schema schema = new Schema(Arrays.asList( + field("TINYINT", MinorType.TINYINT.getType()), + field("SMALLINT", MinorType.SMALLINT.getType()), + field("INT", MinorType.INT.getType()), + field("BIGINT", MinorType.BIGINT.getType()), + field("UINT1", MinorType.UINT1.getType()), + field("UINT2", MinorType.UINT2.getType()), + field("UINT4", MinorType.UINT4.getType()), + field("UINT8", MinorType.UINT8.getType()), + field("FLOAT4", MinorType.FLOAT4.getType()), + field("FLOAT8", MinorType.FLOAT8.getType()), + field("UTF8", MinorType.VARCHAR.getType()), + field("VARBINARY", MinorType.VARBINARY.getType()), + field("BIT", MinorType.BIT.getType()), + field("DECIMAL", new Decimal(38, 5)), + field("FIXEDSIZEBINARY", new FixedSizeBinary(50)), + field("DATEDAY", MinorType.DATEDAY.getType()), + field("DATEMILLI", MinorType.DATEMILLI.getType()), + field("TIMESEC", MinorType.TIMESEC.getType()), + field("TIMEMILLI", MinorType.TIMEMILLI.getType()), + field("TIMEMICRO", MinorType.TIMEMICRO.getType()), + field("TIMENANO", MinorType.TIMENANO.getType()), + field("TIMESTAMPSEC", MinorType.TIMESTAMPSEC.getType()), + field("TIMESTAMPMILLI", MinorType.TIMESTAMPMILLI.getType()), + field("TIMESTAMPMICRO", MinorType.TIMESTAMPMICRO.getType()), + field("TIMESTAMPNANO", MinorType.TIMESTAMPNANO.getType()), + field("TIMESTAMPSECTZ", new Timestamp(TimeUnit.SECOND, "PST")), + field("TIMESTAMPMILLITZ", new Timestamp(TimeUnit.MILLISECOND, "PST")), + field("TIMESTAMPMICROTZ", new Timestamp(TimeUnit.MICROSECOND, "PST")), + field("TIMESTAMPNANOTZ", new Timestamp(TimeUnit.NANOSECOND, "PST")), + field("INTERVALDAY", MinorType.INTERVALDAY.getType()), + 
field("INTERVALYEAR", MinorType.INTERVALYEAR.getType()), + field("DURATION", new Duration(TimeUnit.MILLISECOND)) + )); + + try (BufferAllocator allocator = rootAllocator.newChildAllocator("child", 0, Long.MAX_VALUE)) { + for (Field field : schema.getFields()) { + try (FieldVector vector = field.createVector(allocator)) { + assertEquals(vector.getMinorType(), + Types.getMinorTypeForArrowType(field.getFieldType().getType())); + vector.allocateNew(); + } + } + } + } +} From 07b550c222a54f219038880c141a595a7507e73f Mon Sep 17 00:00:00 2001 From: tianchen Date: Tue, 2 Jul 2019 21:13:04 -0700 Subject: [PATCH 35/52] ARROW-5814: [Java] Implement a HashMap for DictionaryEncoder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Related to [ARROW-5814](https://issues.apache.org/jira/browse/ARROW-5814). As a follow-up of https://github.com/apache/arrow/pull/4698. Implement a Map for DictionaryEncoder to reduce boxing/unboxing operations. Benchmark: DictionaryEncodeHashMapBenchmarks.testHashMap: avgt 5 31151.345 ± 1661.878 ns/op DictionaryEncodeHashMapBenchmarks.testDictionaryEncodeHashMap: avgt 5 15549.902 ± 771.647 ns/op Author: tianchen Closes #4765 from tianchen92/map and squashes the following commits: 38ee5a4af add UT f62003337 add javadoc and change package 10596ad87 fix style 86eb350b3 add interface 98f4c5593 init --- .../DictionaryEncodeHashMapBenchmarks.java | 117 ++++++ .../dictionary/DictionaryEncodeHashMap.java | 368 ++++++++++++++++++ .../arrow/vector/dictionary/ObjectIntMap.java | 50 +++ .../vector/TestDictionaryEncodeHashMap.java | 123 ++++++ 4 files changed, 658 insertions(+) create mode 100644 java/performance/src/test/java/org/apache/arrow/vector/dictionary/DictionaryEncodeHashMapBenchmarks.java create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncodeHashMap.java create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/dictionary/ObjectIntMap.java create mode 100644 
java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryEncodeHashMap.java diff --git a/java/performance/src/test/java/org/apache/arrow/vector/dictionary/DictionaryEncodeHashMapBenchmarks.java b/java/performance/src/test/java/org/apache/arrow/vector/dictionary/DictionaryEncodeHashMapBenchmarks.java new file mode 100644 index 00000000000..e97bff2e1dd --- /dev/null +++ b/java/performance/src/test/java/org/apache/arrow/vector/dictionary/DictionaryEncodeHashMapBenchmarks.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.vector.dictionary; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +/** + * Benchmarks for {@link DictionaryEncodeHashMap}. + */ +@State(Scope.Benchmark) +public class DictionaryEncodeHashMapBenchmarks { + private static final int SIZE = 1000; + + private static final int KEY_LENGTH = 10; + + private List testData = new ArrayList<>(); + + private HashMap hashMap = new HashMap(); + private DictionaryEncodeHashMap dictionaryEncodeHashMap = new DictionaryEncodeHashMap(); + + /** + * Setup benchmarks. + */ + @Setup + public void prepare() { + for (int i = 0; i < SIZE; i++) { + testData.add(getRandomString(KEY_LENGTH)); + } + } + + private String getRandomString(int length) { + String str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + Random random = new Random(); + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < length; i++) { + int number = random.nextInt(62); + sb.append(str.charAt(number)); + } + return sb.toString(); + } + + /** + * Test set/get int values for {@link HashMap}. + * @return useless. To avoid DCE by JIT. 
+ */ + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public int testHashMap() { + for (int i = 0; i < SIZE; i++) { + hashMap.put(testData.get(i), i); + } + for (int i = 0; i < SIZE; i++) { + hashMap.get(testData.get(i)); + } + return 0; + } + + /** + * Test set/get int values for {@link DictionaryEncodeHashMap}. + * @return useless. To avoid DCE by JIT. + */ + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public int testDictionaryEncodeHashMap() { + for (int i = 0; i < SIZE; i++) { + dictionaryEncodeHashMap.put(testData.get(i), i); + } + for (int i = 0; i < SIZE; i++) { + dictionaryEncodeHashMap.get(testData.get(i)); + } + return 0; + } + + @Test + public void evaluate() throws RunnerException { + Options opt = new OptionsBuilder() + .include(DictionaryEncodeHashMapBenchmarks.class.getSimpleName()) + .forks(1) + .build(); + + new Runner(opt).run(); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncodeHashMap.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncodeHashMap.java new file mode 100644 index 00000000000..659a8d6b634 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncodeHashMap.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import java.util.Objects; + +/** + * Implementation of the {@link ObjectIntMap} interface, used for DictionaryEncoder. + * Note that value in this map is always not less than 0, and -1 represents a null value. + * @param key type. + */ +public class DictionaryEncodeHashMap implements ObjectIntMap { + + /** + * Represents a null value in map. + */ + static final int NULL_VALUE = -1; + + /** + * The default initial capacity - MUST be a power of two. + */ + static final int DEFAULT_INITIAL_CAPACITY = 1 << 4; + + /** + * The maximum capacity, used if a higher value is implicitly specified + * by either of the constructors with arguments. + */ + static final int MAXIMUM_CAPACITY = 1 << 30; + + /** + * The load factor used when none specified in constructor. + */ + static final float DEFAULT_LOAD_FACTOR = 0.75f; + + static final Entry[] EMPTY_TABLE = {}; + + /** + * The table, initialized on first use, and resized as + * necessary. When allocated, length is always a power of two. + */ + transient Entry[] table = (Entry[]) EMPTY_TABLE; + + /** + * The number of key-value mappings contained in this map. + */ + transient int size; + + /** + * The next size value at which to resize (capacity * load factor). + */ + int threshold; + + /** + * The load factor for the hash table. + */ + final float loadFactor; + + /** + * Constructs an empty map with the specified initial capacity and load factor. 
+ */ + public DictionaryEncodeHashMap(int initialCapacity, float loadFactor) { + if (initialCapacity < 0) { + throw new IllegalArgumentException("Illegal initial capacity: " + + initialCapacity); + } + if (initialCapacity > MAXIMUM_CAPACITY) { + initialCapacity = MAXIMUM_CAPACITY; + } + if (loadFactor <= 0 || Float.isNaN(loadFactor)) { + throw new IllegalArgumentException("Illegal load factor: " + + loadFactor); + } + this.loadFactor = loadFactor; + this.threshold = initialCapacity; + } + + public DictionaryEncodeHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR); + } + + public DictionaryEncodeHashMap() { + this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR); + } + + /** + * Compute the capacity with given threshold and create init table. + */ + private void inflateTable(int threshold) { + int capacity = roundUpToPowerOf2(threshold); + this.threshold = (int) Math.min(capacity * loadFactor, MAXIMUM_CAPACITY + 1); + table = new Entry[capacity]; + } + + /** + * Computes key.hashCode() and spreads (XORs) higher bits of hash to lower. + */ + static final int hash(Object key) { + int h; + return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16); + } + + /** + * Computes the storage location in an array for the given hashCode. + */ + static int indexFor(int h, int length) { + return h & (length - 1); + } + + /** + * Returns a power of two size for the given size. + */ + static final int roundUpToPowerOf2(int size) { + int n = size - 1; + n |= n >>> 1; + n |= n >>> 2; + n |= n >>> 4; + n |= n >>> 8; + n |= n >>> 16; + return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1; + } + + /** + * Returns the value to which the specified key is mapped, + * or -1 if this map contains no mapping for the key. 
+ */ + @Override + public int get(K key) { + int hash = hash(key); + int index = indexFor(hash, table.length); + for (Entry e = table[index] ; e != null ; e = e.next) { + if ((e.hash == hash) && e.key.equals(key)) { + return e.value; + } + } + return NULL_VALUE; + } + + /** + * Associates the specified value with the specified key in this map. + * If the map previously contained a mapping for the key, the old + * value is replaced. + */ + @Override + public int put(K key, int value) { + if (table == EMPTY_TABLE) { + inflateTable(threshold); + } + + if (key == null) { + return putForNullKey(value); + } + + int hash = hash(key); + int i = indexFor(hash, table.length); + for (Entry e = table[i]; e != null; e = e.next) { + Object k; + if (e.hash == hash && ((k = e.key) == key || key.equals(k))) { + int oldValue = e.value; + e.value = value; + return oldValue; + } + } + + addEntry(hash, key, value, i); + return NULL_VALUE; + } + + /** + * Removes the mapping for the specified key from this map if present. + * @param key key whose mapping is to be removed from the map + * @return the previous value associated with key, or + * -1 if there was no mapping for key. + */ + @Override + public int remove(K key) { + Entry e = removeEntryForKey(key); + return (e == null ? NULL_VALUE : e.value); + } + + /** + * Create a new Entry at the specific position of table. + */ + void createEntry(int hash, K key, int value, int bucketIndex) { + Entry e = table[bucketIndex]; + table[bucketIndex] = new Entry<>(hash, key, value, e); + size++; + } + + /** + * Put value when key is null. + */ + private int putForNullKey(int value) { + for (Entry e = table[0]; e != null; e = e.next) { + if (e.key == null) { + int oldValue = e.value; + e.value = value; + return oldValue; + } + } + addEntry(0, null, value, 0); + return NULL_VALUE; + } + + /** + * Add Entry at the specified location of the table. 
+ */ + void addEntry(int hash, K key, int value, int bucketIndex) { + if ((size >= threshold) && (null != table[bucketIndex])) { + resize(2 * table.length); + hash = (null != key) ? hash(key) : 0; + bucketIndex = indexFor(hash, table.length); + } + + createEntry(hash, key, value, bucketIndex); + } + + /** + * Resize table with given new capacity. + */ + void resize(int newCapacity) { + Entry[] oldTable = table; + int oldCapacity = oldTable.length; + if (oldCapacity == MAXIMUM_CAPACITY) { + threshold = Integer.MAX_VALUE; + return; + } + + Entry[] newTable = new Entry[newCapacity]; + transfer(newTable); + table = newTable; + threshold = (int)Math.min(newCapacity * loadFactor, MAXIMUM_CAPACITY + 1); + } + + /** + * Transfer entries into new table from old table. + * @param newTable new table + */ + void transfer(Entry[] newTable) { + int newCapacity = newTable.length; + for (Entry e : table) { + while (null != e) { + Entry next = e.next; + int i = indexFor(e.hash, newCapacity); + e.next = newTable[i]; + newTable[i] = e; + e = next; + } + } + } + + /** + * Remove entry associated with the given key. + */ + final Entry removeEntryForKey(Object key) { + if (size == 0) { + return null; + } + int hash = (key == null) ? 0 : hash(key); + int i = indexFor(hash, table.length); + Entry prev = table[i]; + Entry e = prev; + + while (e != null) { + Entry next = e.next; + Object k; + if (e.hash == hash && ((k = e.key) == key || (key != null && key.equals(k)))) { + size--; + if (prev == e) { + table[i] = next; + } else { + prev.next = next; + } + + return e; + } + prev = e; + e = next; + } + + return e; + } + + /** + * Returns the number of mappings in this Map. + */ + public int size() { + return size; + } + + /** + * Removes all elements from this map, leaving it empty. + */ + public void clear() { + size = 0; + for (int i = 0; i < table.length; i++) { + table[i] = null; + } + } + + /** + * Class to keep key-value data within hash map. + * @param key type. 
+ */ + static class Entry { + final K key; + int value; + Entry next; + int hash; + + Entry(int hash, K key, int value, Entry next) { + this.key = key; + this.value = value; + this.hash = hash; + this.next = next; + } + + public final K getKey() { + return key; + } + + public final int getValue() { + return value; + } + + public final int setValue(int newValue) { + int oldValue = value; + value = newValue; + return oldValue; + } + + public final boolean equals(Object o) { + if (!(o instanceof DictionaryEncodeHashMap.Entry)) { + return false; + } + Entry e = (Entry) o; + if (Objects.equals(key, e.getKey())) { + if (value == e.getValue()) { + return true; + } + } + return false; + } + + public final int hashCode() { + return Objects.hashCode(key) ^ Objects.hashCode(value); + } + + public final String toString() { + return getKey() + "=" + getValue(); + } + } + +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ObjectIntMap.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ObjectIntMap.java new file mode 100644 index 00000000000..582bb561cb7 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ObjectIntMap.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +/** + * Specific hash map for int type value, reducing boxing/unboxing operations. + * @param key type. + */ +public interface ObjectIntMap { + + /** + * Associates the specified value with the specified key in this map. + * If the map previously contained a mapping for the key, the old + * value is replaced. + * @param key key with which the specified value is to be associated + * @param value value to be associated with the specified key + * @return the previous value associated with key, or + * -1 if there was no mapping for key. + */ + int put(K key, int value); + + /** + * Returns the value to which the specified key is mapped, + * or -1 if this map contains no mapping for the key. + */ + int get(K key); + + /** + * Removes the mapping for the specified key from this map if present. + * @param key key whose mapping is to be removed from the map + * @return the previous value associated with key, or + * -1 if there was no mapping for key. + */ + int remove(K key); +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryEncodeHashMap.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryEncodeHashMap.java new file mode 100644 index 00000000000..5f4e710b8af --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryEncodeHashMap.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import org.apache.arrow.vector.dictionary.DictionaryEncodeHashMap; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + + + +public class TestDictionaryEncodeHashMap { + + private List testData = new ArrayList<>(); + + private static final int SIZE = 100; + + private static final int KEY_LENGTH = 5; + + private DictionaryEncodeHashMap map = new DictionaryEncodeHashMap(); + + @Before + public void init() { + for (int i = 0; i < SIZE; i++) { + testData.add(generateUniqueKey(KEY_LENGTH)); + } + } + + @After + public void terminate() throws Exception { + testData.clear(); + } + + private String generateUniqueKey(int length) { + String str = "abcdefghijklmnopqrstuvwxyz"; + Random random = new Random(); + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < length; i++) { + int number = random.nextInt(26); + sb.append(str.charAt(number)); + } + if (testData.contains(sb.toString())) { + return generateUniqueKey(length); + } + return sb.toString(); + } + + @Test + public void testPutAndGet() { + for (int i = 0; i < SIZE; i++) { + map.put(testData.get(i), i); + } + + for (int i = 0; i < SIZE; i++) { + assertEquals(i, map.get(testData.get(i))); + } + } + + @Test + public void testPutExistKey() { + for (int i = 0; i < SIZE; i++) { + map.put(testData.get(i), i); + } + map.put("test_key", 101); + assertEquals(101, map.get("test_key")); + map.put("test_key", 102); 
+ assertEquals(102, map.get("test_key")); + } + + @Test + public void testGetNonExistKey() { + for (int i = 0; i < SIZE; i++) { + map.put(testData.get(i), i); + } + //remove if exists. + map.remove("test_key"); + assertEquals(-1, map.get("test_key")); + } + + @Test + public void testRemove() { + for (int i = 0; i < SIZE; i++) { + map.put(testData.get(i), i); + } + map.put("test_key", 10000); + assertEquals(SIZE + 1, map.size()); + assertEquals(10000, map.get("test_key")); + map.remove("test_key"); + assertEquals(SIZE, map.size()); + assertEquals(-1, map.get("test_key")); + } + + @Test + public void testSize() { + assertEquals(0, map.size()); + for (int i = 0; i < SIZE; i++) { + map.put(testData.get(i), i); + } + assertEquals(SIZE, map.size()); + } +} From e28d3662a008b7a4a2ebbc5b3a6265c13970e742 Mon Sep 17 00:00:00 2001 From: liyafan82 Date: Tue, 2 Jul 2019 21:15:37 -0700 Subject: [PATCH 36/52] ARROW-5778: [Java] Extract the logic for vector data copying to the super classes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, each vector has its own copyFrom method. The implementations for fixed-width vectors are similar to one another, and so are the implementations for the variable-width vectors. This issue extracts such implementations to the base classes, with the following benefits: 1. Less code makes it easier to maintain 2. Moving the method to the base class makes the method more convenient to use. 
Benchmark evaluations show that there are minor performance improvements, but the overall performance remains almost identical: Fixed width data type (float8) before: Float8Benchmarks.copyFromBenchmark avgt 4 17.809 ± 0.022 us/op after: Float8Benchmarks.copyFromBenchmark avgt 4 17.712 ± 0.050 us/op Variable width data type (varchar) before: VarCharBenchmarks.copyFromBenchmark avgt 5 36.455 ± 0.125 us/op after: VarCharBenchmarks.copyFromBenchmark avgt 5 33.015 ± 0.065 us/op Author: liyafan82 Closes #4764 from liyafan82/fly_5778 and squashes the following commits: 54193917e Merge branch 'master' into fly_5778 1612c9e8f Extract the logic for vector data copying to the super classes --- .../apache/arrow/vector/Float8Benchmarks.java | 24 ++++ .../arrow/vector/VarCharBenchmarks.java | 104 ++++++++++++++++++ .../arrow/vector/BaseFixedWidthVector.java | 33 ++++++ .../arrow/vector/BaseVariableWidthVector.java | 57 ++++++++++ .../org/apache/arrow/vector/BigIntVector.java | 29 ----- .../org/apache/arrow/vector/BitVector.java | 20 +--- .../apache/arrow/vector/DateDayVector.java | 29 ----- .../apache/arrow/vector/DateMilliVector.java | 29 ----- .../apache/arrow/vector/DecimalVector.java | 28 ----- .../apache/arrow/vector/DurationVector.java | 30 ----- .../arrow/vector/FixedSizeBinaryVector.java | 28 ----- .../org/apache/arrow/vector/Float4Vector.java | 29 ----- .../org/apache/arrow/vector/Float8Vector.java | 28 ----- .../org/apache/arrow/vector/IntVector.java | 29 ----- .../arrow/vector/IntervalDayVector.java | 29 ----- .../arrow/vector/IntervalYearVector.java | 29 ----- .../apache/arrow/vector/SmallIntVector.java | 29 ----- .../apache/arrow/vector/TimeMicroVector.java | 29 ----- .../apache/arrow/vector/TimeMilliVector.java | 28 ----- .../apache/arrow/vector/TimeNanoVector.java | 28 ----- .../apache/arrow/vector/TimeSecVector.java | 28 ----- .../apache/arrow/vector/TimeStampVector.java | 28 ----- .../apache/arrow/vector/TinyIntVector.java | 28 ----- 
.../org/apache/arrow/vector/UInt1Vector.java | 18 --- .../org/apache/arrow/vector/UInt2Vector.java | 16 --- .../org/apache/arrow/vector/UInt4Vector.java | 19 ---- .../org/apache/arrow/vector/UInt8Vector.java | 19 ---- .../apache/arrow/vector/VarBinaryVector.java | 42 ------- .../apache/arrow/vector/VarCharVector.java | 42 ------- 29 files changed, 221 insertions(+), 688 deletions(-) create mode 100644 java/performance/src/test/java/org/apache/arrow/vector/VarCharBenchmarks.java diff --git a/java/performance/src/test/java/org/apache/arrow/vector/Float8Benchmarks.java b/java/performance/src/test/java/org/apache/arrow/vector/Float8Benchmarks.java index 9ab6e375eaf..4617f5bf9bc 100644 --- a/java/performance/src/test/java/org/apache/arrow/vector/Float8Benchmarks.java +++ b/java/performance/src/test/java/org/apache/arrow/vector/Float8Benchmarks.java @@ -50,6 +50,8 @@ public class Float8Benchmarks { private Float8Vector vector; + private Float8Vector fromVector; + /** * Setup benchmarks. */ @@ -58,6 +60,18 @@ public void prepare() { allocator = new RootAllocator(ALLOCATOR_CAPACITY); vector = new Float8Vector("vector", allocator); vector.allocateNew(VECTOR_LENGTH); + + fromVector = new Float8Vector("vector", allocator); + fromVector.allocateNew(VECTOR_LENGTH); + + for (int i = 0;i < VECTOR_LENGTH; i++) { + if (i % 3 == 0) { + fromVector.setNull(i); + } else { + fromVector.set(i, i * i); + } + } + fromVector.setValueCount(VECTOR_LENGTH); } /** @@ -66,6 +80,7 @@ public void prepare() { @TearDown public void tearDown() { vector.close(); + fromVector.close(); allocator.close(); } @@ -88,6 +103,15 @@ public double readWriteBenchmark() { return sum; } + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.MICROSECONDS) + public void copyFromBenchmark() { + for (int i = 0; i < VECTOR_LENGTH; i++) { + vector.copyFrom(i, i, (Float8Vector) fromVector); + } + } + @Test public void evaluate() throws RunnerException { Options opt = new OptionsBuilder() diff --git 
a/java/performance/src/test/java/org/apache/arrow/vector/VarCharBenchmarks.java b/java/performance/src/test/java/org/apache/arrow/vector/VarCharBenchmarks.java new file mode 100644 index 00000000000..39ff9c05a35 --- /dev/null +++ b/java/performance/src/test/java/org/apache/arrow/vector/VarCharBenchmarks.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.Test; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +/** + * Benchmarks for {@link VarCharVector}. 
+ */ +@State(Scope.Benchmark) +public class VarCharBenchmarks { + + private static final int VECTOR_LENGTH = 1024; + + private static final int ALLOCATOR_CAPACITY = 1024 * 1024; + + private BufferAllocator allocator; + + private VarCharVector vector; + + private VarCharVector fromVector; + + /** + * Setup benchmarks. + */ + @Setup + public void prepare() { + allocator = new RootAllocator(ALLOCATOR_CAPACITY); + vector = new VarCharVector("vector", allocator); + vector.allocateNew(ALLOCATOR_CAPACITY / 4, VECTOR_LENGTH); + + fromVector = new VarCharVector("vector", allocator); + fromVector.allocateNew(ALLOCATOR_CAPACITY / 4, VECTOR_LENGTH); + + for (int i = 0;i < VECTOR_LENGTH; i++) { + if (i % 3 == 0) { + fromVector.setNull(i); + } else { + fromVector.set(i, String.valueOf(i * 1000).getBytes()); + } + } + fromVector.setValueCount(VECTOR_LENGTH); + } + + /** + * Tear down benchmarks. + */ + @TearDown + public void tearDown() { + vector.close(); + fromVector.close(); + allocator.close(); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.MICROSECONDS) + public void copyFromBenchmark() { + for (int i = 0; i < VECTOR_LENGTH; i++) { + vector.copyFrom(i, i, fromVector); + } + } + + @Test + public void evaluate() throws RunnerException { + Options opt = new OptionsBuilder() + .include(VarCharBenchmarks.class.getSimpleName()) + .forks(1) + .build(); + + new Runner(opt).run(); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java index e4bed23ba4f..8feca751adf 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java @@ -30,6 +30,7 @@ import org.apache.arrow.vector.util.TransferPair; import io.netty.buffer.ArrowBuf; +import io.netty.util.internal.PlatformDependent; /** * BaseFixedWidthVector provides an abstract 
interface for @@ -804,4 +805,36 @@ protected void handleSafe(int index) { reAlloc(); } } + + /** + * Copy a cell value from a particular index in source vector to a particular + * position in this vector. + * + * @param fromIndex position to copy from in source vector + * @param thisIndex position to copy to in this vector + * @param from source vector + */ + public void copyFrom(int fromIndex, int thisIndex, BaseFixedWidthVector from) { + if (from.isNull(fromIndex)) { + BitVectorHelper.setValidityBit(this.getValidityBuffer(), thisIndex, 0); + } else { + BitVectorHelper.setValidityBit(this.getValidityBuffer(), thisIndex, 1); + PlatformDependent.copyMemory(from.getDataBuffer().memoryAddress() + fromIndex * typeWidth, + this.getDataBuffer().memoryAddress() + thisIndex * typeWidth, typeWidth); + } + } + + /** + * Same as {@link #copyFrom(int, int, BaseFixedWidthVector)} except that + * it handles the case when the capacity of the vector needs to be expanded + * before copy. + * + * @param fromIndex position to copy from in source vector + * @param thisIndex position to copy to in this vector + * @param from source vector + */ + public void copyFromSafe(int fromIndex, int thisIndex, BaseFixedWidthVector from) { + handleSafe(thisIndex); + copyFrom(fromIndex, thisIndex, from); + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java index e7fa28964d7..5262f339e22 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java @@ -1277,4 +1277,61 @@ public static ArrowBuf set(ArrowBuf buffer, BufferAllocator allocator, return buffer; } + + /** + * Copy a cell value from a particular index in source vector to a particular + * position in this vector. 
+ * + * @param fromIndex position to copy from in source vector + * @param thisIndex position to copy to in this vector + * @param from source vector + */ + public void copyFrom(int fromIndex, int thisIndex, BaseVariableWidthVector from) { + if (from.isNull(fromIndex)) { + fillHoles(thisIndex); + BitVectorHelper.setValidityBit(this.validityBuffer, thisIndex, 0); + final int copyStart = offsetBuffer.getInt(thisIndex * OFFSET_WIDTH); + offsetBuffer.setInt((thisIndex + 1) * OFFSET_WIDTH, copyStart); + } else { + final int start = from.offsetBuffer.getInt(fromIndex * OFFSET_WIDTH); + final int end = from.offsetBuffer.getInt((fromIndex + 1) * OFFSET_WIDTH); + final int length = end - start; + fillHoles(thisIndex); + BitVectorHelper.setValidityBit(this.validityBuffer, thisIndex, 1); + final int copyStart = offsetBuffer.getInt(thisIndex * OFFSET_WIDTH); + from.valueBuffer.getBytes(start, this.valueBuffer, copyStart, length); + offsetBuffer.setInt((thisIndex + 1) * OFFSET_WIDTH, copyStart + length); + } + lastSet = thisIndex; + } + + /** + * Same as {@link #copyFrom(int, int, BaseVariableWidthVector)} except that + * it handles the case when the capacity of the vector needs to be expanded + * before copy. 
+ * + * @param fromIndex position to copy from in source vector + * @param thisIndex position to copy to in this vector + * @param from source vector + */ + public void copyFromSafe(int fromIndex, int thisIndex, BaseVariableWidthVector from) { + if (from.isNull(fromIndex)) { + handleSafe(thisIndex, 0); + fillHoles(thisIndex); + BitVectorHelper.setValidityBit(this.validityBuffer, thisIndex, 0); + final int copyStart = offsetBuffer.getInt(thisIndex * OFFSET_WIDTH); + offsetBuffer.setInt((thisIndex + 1) * OFFSET_WIDTH, copyStart); + } else { + final int start = from.offsetBuffer.getInt(fromIndex * OFFSET_WIDTH); + final int end = from.offsetBuffer.getInt((fromIndex + 1) * OFFSET_WIDTH); + final int length = end - start; + handleSafe(thisIndex, length); + fillHoles(thisIndex); + BitVectorHelper.setValidityBit(this.validityBuffer, thisIndex, 1); + final int copyStart = offsetBuffer.getInt(thisIndex * OFFSET_WIDTH); + from.valueBuffer.getBytes(start, this.valueBuffer, copyStart, length); + offsetBuffer.setInt((thisIndex + 1) * OFFSET_WIDTH, copyStart + length); + } + lastSet = thisIndex; + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java index 6d235dd1e6c..1001cec6679 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java @@ -145,35 +145,6 @@ public Long getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, BigIntVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final long value = from.valueBuffer.getLong(fromIndex * TYPE_WIDTH); - valueBuffer.setLong(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, BigIntVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, BigIntVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java index f75ccdc69d0..ff4504fe958 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java @@ -296,26 +296,12 @@ public Boolean getObject(int index) { * @param thisIndex position to copy to in this vector * @param from source vector */ - public void copyFrom(int fromIndex, int thisIndex, BitVector from) { + @Override + public void copyFrom(int fromIndex, int thisIndex, BaseFixedWidthVector from) { BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - BitVectorHelper.setValidityBit(valueBuffer, thisIndex, from.getBit(fromIndex)); - } - - /** - * Same as {@link #copyFrom(int, int, BitVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, BitVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); + BitVectorHelper.setValidityBit(valueBuffer, thisIndex, ((BitVector) from).getBit(fromIndex)); } - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java index e634e7ef334..72af5def9fc 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java @@ -146,35 +146,6 @@ public Integer getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, DateDayVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final int value = from.valueBuffer.getInt(fromIndex * TYPE_WIDTH); - valueBuffer.setInt(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, DateDayVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, DateDayVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java index 7ea427dd2a1..be4fcbe26e9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java @@ -150,35 +150,6 @@ public LocalDateTime getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, DateMilliVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final long value = from.valueBuffer.getLong(fromIndex * TYPE_WIDTH); - valueBuffer.setLong(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, DateMilliVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, DateMilliVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java index 4fc35a33126..1db83a1ae2f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java @@ -162,34 +162,6 @@ public BigDecimal getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, DecimalVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - from.valueBuffer.getBytes(fromIndex * TYPE_WIDTH, valueBuffer, - thisIndex * TYPE_WIDTH, TYPE_WIDTH); - } - - /** - * Same as {@link #copyFrom(int, int, DecimalVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, DecimalVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /** * Return scale for the decimal value. 
*/ diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java index 76572bc9553..92a9e7044b8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java @@ -193,42 +193,12 @@ private StringBuilder getAsStringBuilderHelper(int index) { return new StringBuilder(getObject(index).toString()); } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, DurationVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - from.valueBuffer.getBytes(fromIndex * TYPE_WIDTH, this.valueBuffer, - thisIndex * TYPE_WIDTH, TYPE_WIDTH); - } - - /** - * Same as {@link #copyFrom(int, int, DurationVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, DurationVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | | | *----------------------------------------------------------------*/ - /** * Set the element at the given index to the given value. 
* diff --git a/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java index 50179de2672..61bd57c1356 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java @@ -158,34 +158,6 @@ public byte[] getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, FixedSizeBinaryVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - from.valueBuffer.getBytes(fromIndex * byteWidth, valueBuffer, - thisIndex * byteWidth, byteWidth); - } - - /** - * Same as {@link #copyFrom(int, int, FixedSizeBinaryVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, FixedSizeBinaryVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - public int getByteWidth() { return byteWidth; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/Float4Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/Float4Vector.java index 7976b029b0d..96b5625c54a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/Float4Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/Float4Vector.java @@ -147,35 +147,6 @@ public Float getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, Float4Vector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final float value = from.valueBuffer.getFloat(fromIndex * TYPE_WIDTH); - valueBuffer.setFloat(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, Float4Vector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, Float4Vector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/Float8Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/Float8Vector.java index 109bf5d556c..24128cdbb88 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/Float8Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/Float8Vector.java @@ -147,34 +147,6 @@ public Double getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, Float8Vector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final double value = from.valueBuffer.getDouble(fromIndex * TYPE_WIDTH); - valueBuffer.setDouble(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, Float8Vector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, Float8Vector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java index 43e5f2c4ab7..4ce5e1b92f8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java @@ -147,35 +147,6 @@ public Integer getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, IntVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final int value = from.valueBuffer.getInt(fromIndex * TYPE_WIDTH); - valueBuffer.setInt(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, IntVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, IntVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntervalDayVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntervalDayVector.java index 1bdbd48a8b4..0d7125b7a0e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntervalDayVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntervalDayVector.java @@ -223,35 +223,6 @@ private StringBuilder getAsStringBuilderHelper(int index) { .append(millis)); } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, IntervalDayVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - from.valueBuffer.getBytes(fromIndex * TYPE_WIDTH, this.valueBuffer, - thisIndex * TYPE_WIDTH, TYPE_WIDTH); - } - - /** - * Same as {@link #copyFrom(int, int, IntervalDayVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, IntervalDayVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntervalYearVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntervalYearVector.java index 65e3b59052a..2b73d02504d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntervalYearVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntervalYearVector.java @@ -196,35 +196,6 @@ private StringBuilder getAsStringBuilderHelper(int index) { .append(monthString)); } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, IntervalYearVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final int value = from.valueBuffer.getInt(fromIndex * TYPE_WIDTH); - valueBuffer.setInt(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, IntervalYearVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, IntervalYearVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java index 1a06bcd77e3..165b774c8dc 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java @@ -148,35 +148,6 @@ public Short getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, SmallIntVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final short value = from.valueBuffer.getShort(fromIndex * TYPE_WIDTH); - valueBuffer.setShort(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, SmallIntVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, SmallIntVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeMicroVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeMicroVector.java index 0f58babc825..089164c0016 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeMicroVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeMicroVector.java @@ -147,35 +147,6 @@ public Long getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, TimeMicroVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final long value = from.valueBuffer.getLong(fromIndex * TYPE_WIDTH); - valueBuffer.setLong(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, TimeMicroVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, TimeMicroVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - - /*----------------------------------------------------------------* | | | vector value setter methods | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeMilliVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeMilliVector.java index 5552353d0d4..9f41c84527d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeMilliVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeMilliVector.java @@ -150,34 +150,6 @@ public LocalDateTime getObject(int index) { return DateUtility.getLocalDateTimeFromEpochMilli(millis); } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, TimeMilliVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final int value = from.valueBuffer.getInt(fromIndex * TYPE_WIDTH); - valueBuffer.setInt(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, TimeMilliVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, TimeMilliVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeNanoVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeNanoVector.java index 2c059522364..053a722430f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeNanoVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeNanoVector.java @@ -147,34 +147,6 @@ public Long getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, TimeNanoVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final long value = from.valueBuffer.getLong(fromIndex * TYPE_WIDTH); - valueBuffer.setLong(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, TimeNanoVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, TimeNanoVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeSecVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeSecVector.java index 02e9dcb76ba..15992af79d6 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeSecVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeSecVector.java @@ -147,34 +147,6 @@ public Integer getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, TimeSecVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final int value = from.valueBuffer.getInt(fromIndex * TYPE_WIDTH); - valueBuffer.setInt(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, TimeSecVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, TimeSecVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampVector.java index dea0510789e..53bcbc0aacf 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampVector.java @@ -77,34 +77,6 @@ public long get(int index) throws IllegalStateException { return valueBuffer.getLong(index * TYPE_WIDTH); } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, TimeStampVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final long value = from.valueBuffer.getLong(fromIndex * TYPE_WIDTH); - valueBuffer.setLong(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFromSafe(int, int, TimeStampVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, TimeStampVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java index 673740ded50..97a0ea18b16 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java @@ -148,34 +148,6 @@ public Byte getObject(int index) { } } - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, TinyIntVector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final byte value = from.valueBuffer.getByte(fromIndex * TYPE_WIDTH); - valueBuffer.setByte(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, TinyIntVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, TinyIntVector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java index a57e1ef169b..5bf843c91c4 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java @@ -146,24 +146,6 @@ public Short getObjectNoOverflow(int index) { } } - /** - * Copies the value at fromIndex to thisIndex (including validity). - */ - public void copyFrom(int fromIndex, int thisIndex, UInt1Vector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final byte value = from.valueBuffer.getByte(fromIndex * TYPE_WIDTH); - valueBuffer.setByte(thisIndex * TYPE_WIDTH, value); - } - - /** - * Identical to {@link #copyFrom(int, int, UInt1Vector)} but reallocates buffer if index is larger - * than capacity. - */ - public void copyFromSafe(int fromIndex, int thisIndex, UInt1Vector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java index c29adf6aecb..7e7c5441ecb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java @@ -127,22 +127,6 @@ public Character getObject(int index) { } } - /** Copies a value and validity bit from the given vector to this one. 
*/ - public void copyFrom(int fromIndex, int thisIndex, UInt2Vector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final char value = from.valueBuffer.getChar(fromIndex * TYPE_WIDTH); - valueBuffer.setChar(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, UInt2Vector)} but reallocate buffer if - * index is larger than capacity. - */ - public void copyFromSafe(int fromIndex, int thisIndex, UInt2Vector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java index 65d0e32dd96..525d82c33b1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java @@ -144,25 +144,6 @@ public Long getObjectNoOverflow(int index) { } } - /** - * Copies a value and validity setting to the thisIndex position from the given vector - * at fromIndex. - */ - public void copyFrom(int fromIndex, int thisIndex, UInt4Vector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final int value = from.valueBuffer.getInt(fromIndex * TYPE_WIDTH); - valueBuffer.setInt(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, UInt4Vector)} but will allocate additional space - * if fromIndex is larger than current capacity. 
- */ - public void copyFromSafe(int fromIndex, int thisIndex, UInt4Vector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java index e7561326da0..5bd54c0451b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java @@ -149,25 +149,6 @@ public BigInteger getObjectNoOverflow(int index) { } } - /** - * Copy a value and validity setting from fromIndex in from to this - * Vector at thisIndex. - */ - public void copyFrom(int fromIndex, int thisIndex, UInt8Vector from) { - BitVectorHelper.setValidityBit(validityBuffer, thisIndex, from.isSet(fromIndex)); - final long value = from.valueBuffer.getLong(fromIndex * TYPE_WIDTH); - valueBuffer.setLong(thisIndex * TYPE_WIDTH, value); - } - - /** - * Same as {@link #copyFrom(int, int, UInt8Vector)} but reallocates if thisIndex is - * larger then current capacity. - */ - public void copyFromSafe(int fromIndex, int thisIndex, UInt8Vector from) { - handleSafe(thisIndex); - copyFrom(fromIndex, thisIndex, from); - } - /*----------------------------------------------------------------* | | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java index ac6df0dedcf..bd76f3cc03f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java @@ -163,48 +163,6 @@ public void get(int index, NullableVarBinaryHolder holder) { *----------------------------------------------------------------*/ - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, VarBinaryVector from) { - final int start = from.offsetBuffer.getInt(fromIndex * OFFSET_WIDTH); - final int end = from.offsetBuffer.getInt((fromIndex + 1) * OFFSET_WIDTH); - final int length = end - start; - fillHoles(thisIndex); - BitVectorHelper.setValidityBit(this.validityBuffer, thisIndex, from.isSet(fromIndex)); - final int copyStart = offsetBuffer.getInt(thisIndex * OFFSET_WIDTH); - from.valueBuffer.getBytes(start, this.valueBuffer, copyStart, length); - offsetBuffer.setInt((thisIndex + 1) * OFFSET_WIDTH, copyStart + length); - lastSet = thisIndex; - } - - /** - * Same as {@link #copyFrom(int, int, VarBinaryVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, VarBinaryVector from) { - final int start = from.offsetBuffer.getInt(fromIndex * OFFSET_WIDTH); - final int end = from.offsetBuffer.getInt((fromIndex + 1) * OFFSET_WIDTH); - final int length = end - start; - handleSafe(thisIndex, length); - fillHoles(thisIndex); - BitVectorHelper.setValidityBit(this.validityBuffer, thisIndex, from.isSet(fromIndex)); - final int copyStart = offsetBuffer.getInt(thisIndex * OFFSET_WIDTH); - from.valueBuffer.getBytes(start, this.valueBuffer, copyStart, length); - offsetBuffer.setInt((thisIndex + 1) * OFFSET_WIDTH, copyStart + length); - lastSet = thisIndex; - } - /** * Set the variable length element at the specified index to the data * buffer supplied in the holder. 
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java b/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java index 3e035fa6449..c012ce3cf30 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java @@ -162,48 +162,6 @@ public void get(int index, NullableVarCharHolder holder) { *----------------------------------------------------------------*/ - /** - * Copy a cell value from a particular index in source vector to a particular - * position in this vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFrom(int fromIndex, int thisIndex, VarCharVector from) { - final int start = from.offsetBuffer.getInt(fromIndex * OFFSET_WIDTH); - final int end = from.offsetBuffer.getInt((fromIndex + 1) * OFFSET_WIDTH); - final int length = end - start; - fillHoles(thisIndex); - BitVectorHelper.setValidityBit(this.validityBuffer, thisIndex, from.isSet(fromIndex)); - final int copyStart = offsetBuffer.getInt(thisIndex * OFFSET_WIDTH); - from.valueBuffer.getBytes(start, this.valueBuffer, copyStart, length); - offsetBuffer.setInt((thisIndex + 1) * OFFSET_WIDTH, copyStart + length); - lastSet = thisIndex; - } - - /** - * Same as {@link #copyFrom(int, int, VarCharVector)} except that - * it handles the case when the capacity of the vector needs to be expanded - * before copy. 
- * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - */ - public void copyFromSafe(int fromIndex, int thisIndex, VarCharVector from) { - final int start = from.offsetBuffer.getInt(fromIndex * OFFSET_WIDTH); - final int end = from.offsetBuffer.getInt((fromIndex + 1) * OFFSET_WIDTH); - final int length = end - start; - handleSafe(thisIndex, length); - fillHoles(thisIndex); - BitVectorHelper.setValidityBit(this.validityBuffer, thisIndex, from.isSet(fromIndex)); - final int copyStart = offsetBuffer.getInt(thisIndex * OFFSET_WIDTH); - from.valueBuffer.getBytes(start, this.valueBuffer, copyStart, length); - offsetBuffer.setInt((thisIndex + 1) * OFFSET_WIDTH, copyStart + length); - lastSet = thisIndex; - } - /** * Set the variable length element at the specified index to the data * buffer supplied in the holder. From e33a2b69a469b0293d1e8a8cbad5cf24218a0b0e Mon Sep 17 00:00:00 2001 From: tianchen Date: Tue, 2 Jul 2019 21:20:02 -0700 Subject: [PATCH 37/52] ARROW-5812: [Java] Refactor method name and param type in BaseIntVector Related to [ARROW-5812](https://issues.apache.org/jira/browse/ARROW-5812). Change to void setWithPossibleTruncate(int index, long value); for better generality. 
Author: tianchen Closes #4763 from tianchen92/ARROW-5726-follow and squashes the following commits: c26941a77 remove warnings 18b907736 silent truncate 9bbeb3d9b refactor --- .../main/java/org/apache/arrow/vector/BaseIntVector.java | 4 ++-- .../main/java/org/apache/arrow/vector/BigIntVector.java | 2 +- .../src/main/java/org/apache/arrow/vector/IntVector.java | 4 ++-- .../main/java/org/apache/arrow/vector/SmallIntVector.java | 6 ++---- .../main/java/org/apache/arrow/vector/TinyIntVector.java | 6 ++---- .../src/main/java/org/apache/arrow/vector/UInt1Vector.java | 7 +++---- .../src/main/java/org/apache/arrow/vector/UInt2Vector.java | 6 ++---- .../src/main/java/org/apache/arrow/vector/UInt4Vector.java | 4 ++-- .../src/main/java/org/apache/arrow/vector/UInt8Vector.java | 2 +- .../apache/arrow/vector/dictionary/DictionaryEncoder.java | 2 +- 10 files changed, 18 insertions(+), 25 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseIntVector.java index 74387de9486..6b5b53aa726 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseIntVector.java @@ -23,7 +23,7 @@ public interface BaseIntVector extends ValueVector { /** - * set the encoded value from a {@link org.apache.arrow.vector.dictionary.Dictionary}. + * set value at specific index, note this value may need to be need truncated. 
*/ - void setEncodedValue(int index, int value); + void setWithPossibleTruncate(int index, long value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java index 1001cec6679..2c75a8d7557 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java @@ -323,7 +323,7 @@ public TransferPair makeTransferPair(ValueVector to) { } @Override - public void setEncodedValue(int index, int value) { + public void setWithPossibleTruncate(int index, long value) { this.setSafe(index, value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java index 4ce5e1b92f8..e4deabaf015 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java @@ -327,8 +327,8 @@ public TransferPair makeTransferPair(ValueVector to) { } @Override - public void setEncodedValue(int index, int value) { - this.setSafe(index, value); + public void setWithPossibleTruncate(int index, long value) { + this.setSafe(index, (int) value); } private class TransferImpl implements TransferPair { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java index 165b774c8dc..f1f064bdb03 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java @@ -20,7 +20,6 @@ import static org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.complex.impl.SmallIntReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import 
org.apache.arrow.vector.holders.NullableSmallIntHolder; @@ -355,9 +354,8 @@ public TransferPair makeTransferPair(ValueVector to) { } @Override - public void setEncodedValue(int index, int value) { - Preconditions.checkArgument(value <= Short.MAX_VALUE, "value is overflow: %s", value); - this.setSafe(index, value); + public void setWithPossibleTruncate(int index, long value) { + this.setSafe(index, (int) value); } private class TransferImpl implements TransferPair { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java index 97a0ea18b16..d74adf24aad 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java @@ -20,7 +20,6 @@ import static org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.complex.impl.TinyIntReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.NullableTinyIntHolder; @@ -356,9 +355,8 @@ public TransferPair makeTransferPair(ValueVector to) { } @Override - public void setEncodedValue(int index, int value) { - Preconditions.checkArgument(value <= Byte.MAX_VALUE, "value is overflow: %s", value); - this.setSafe(index, value); + public void setWithPossibleTruncate(int index, long value) { + this.setSafe(index, (int) value); } private class TransferImpl implements TransferPair { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java index 5bf843c91c4..8e6d4fe5811 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java @@ -20,7 +20,6 @@ import static 
org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.complex.impl.UInt1ReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.NullableUInt1Holder; @@ -32,6 +31,7 @@ import io.netty.buffer.ArrowBuf; + /** * UInt1Vector implements a fixed width (1 bytes) vector of * integer values which could be null. A validity buffer (bit vector) is @@ -318,9 +318,8 @@ public TransferPair makeTransferPair(ValueVector to) { } @Override - public void setEncodedValue(int index, int value) { - Preconditions.checkArgument(value <= 0xFF, "value is overflow: %s", value); - this.setSafe(index, value); + public void setWithPossibleTruncate(int index, long value) { + this.setSafe(index, (int) value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java index 7e7c5441ecb..53bb16753ec 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java @@ -20,7 +20,6 @@ import static org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.complex.impl.UInt2ReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.NullableUInt2Holder; @@ -299,9 +298,8 @@ public TransferPair makeTransferPair(ValueVector to) { } @Override - public void setEncodedValue(int index, int value) { - Preconditions.checkArgument(value <= Character.MAX_VALUE, "value is overflow: %s", value); - this.setSafe(index, value); + public void setWithPossibleTruncate(int index, long value) { + this.setSafe(index, (int) value); } private class TransferImpl implements TransferPair { 
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java index 525d82c33b1..b54b3f78bbd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java @@ -288,8 +288,8 @@ public TransferPair makeTransferPair(ValueVector to) { } @Override - public void setEncodedValue(int index, int value) { - this.setSafe(index, value); + public void setWithPossibleTruncate(int index, long value) { + this.setSafe(index, (int) value); } private class TransferImpl implements TransferPair { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java index 5bd54c0451b..2451ab5b2e2 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java @@ -289,7 +289,7 @@ public TransferPair makeTransferPair(ValueVector to) { } @Override - public void setEncodedValue(int index, int value) { + public void setWithPossibleTruncate(int index, long value) { this.setSafe(index, value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java index 698191c2ca2..b9f547c8fb5 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -78,7 +78,7 @@ public static ValueVector encode(ValueVector vector, Dictionary dictionary) { if (encoded == null) { throw new IllegalArgumentException("Dictionary encoding not defined for value:" + value); } - indices.setEncodedValue(i, encoded); + indices.setWithPossibleTruncate(i, encoded); } } From ecbe116b778aa02488d1a24c7073ebf7079f5bbe Mon Sep 17 00:00:00 2001 
From: liyafan82 Date: Tue, 2 Jul 2019 21:48:06 -0700 Subject: [PATCH 38/52] ARROW-5658: [JAVA] Sync schema for VectorSchemaRoot Resolve JIRA ARROW-5658. The fundamental problem is that, as data are inserted to the vector (e.g. ListVector), the schema of VectorSchemaRoot can be different from the vector structure. On the server side, the deserialization is based on the schema, which is out of date, so it fails silently. In this PR, we fix the problem of obsolete schema and make the server print error info explicitly. Author: liyafan82 Closes #4689 from liyafan82/fly_5658 and squashes the following commits: e7d5865a5 Undo throwing exception 3ea77bc76 Replace automatic updating schema with throwing an exception cb9da2032 Merge branch 'master' into fly_5658 47a776fbe Automatically update schema 6d2763848 Resolve comments 061e8bc2f Fix error log e8ea49f00 Sync schema for VectorSchemaRoot --- .../apache/arrow/flight/FlightService.java | 1 + .../apache/arrow/vector/VectorSchemaRoot.java | 22 ++++- .../arrow/vector/TestVectorSchemaRoot.java | 93 +++++++++++++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java index ee45cef24d3..e805917cd8f 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java @@ -191,6 +191,7 @@ public StreamObserver doPutCustom(final StreamObserver fieldVectors; private final Map fieldVectorsMap = new HashMap<>(); @@ -206,4 +206,24 @@ public String contentToTSVString() { } return sb.toString(); } + + /** + * Synchronizes the schema from the current vectors. + * In some cases, the schema and the actual vector structure may be different. + * This can be caused by a promoted writer (For details, please see + * {@link org.apache.arrow.vector.complex.impl.PromotableWriter}).
+ * For example, when writing different types of data to a {@link org.apache.arrow.vector.complex.ListVector} + * may lead to such a case. + * When this happens, this method should be called to bring the schema and vector structure in a synchronized state. + * @return true if the schema is updated, false otherwise. + */ + public boolean syncSchema() { + List oldFields = this.schema.getFields(); + List newFields = this.fieldVectors.stream().map(ValueVector::getField).collect(Collectors.toList()); + if (!oldFields.equals(newFields)) { + this.schema = new Schema(newFields); + return true; + } + return false; + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorSchemaRoot.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorSchemaRoot.java index 480dcaca0f4..f9525f45c1a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorSchemaRoot.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorSchemaRoot.java @@ -17,10 +17,24 @@ package org.apache.arrow.vector; +import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -75,4 +89,83 @@ private void checkCount(BitVector vec1, IntVector vec2, VectorSchemaRoot vsr, in assertEquals(vec2.getValueCount(), count); 
assertEquals(vsr.getRowCount(), count); } + + private VectorSchemaRoot createBatch() { + FieldType varCharType = new FieldType(true, new ArrowType.Utf8(), /*dictionary=*/null); + FieldType listType = new FieldType(true, new ArrowType.List(), /*dictionary=*/null); + + // create the schema + List schemaFields = new ArrayList<>(); + Field childField = new Field("varCharCol", varCharType, null); + List childFields = new ArrayList<>(); + childFields.add(childField); + schemaFields.add(new Field("listCol", listType, childFields)); + Schema schema = new Schema(schemaFields); + + VectorSchemaRoot schemaRoot = VectorSchemaRoot.create(schema, allocator); + // get and allocate the vector + ListVector vector = (ListVector) schemaRoot.getVector("listCol"); + vector.allocateNew(); + + // write data to the vector + UnionListWriter writer = vector.getWriter(); + + writer.setPosition(0); + + // write data vector(0) + writer.startList(); + + // write data vector(0)(0) + writer.list().startList(); + + // According to the schema above, the list element should have varchar type. + // When we write a big int, the original writer cannot handle this, so the writer will + // be promoted, and the vector structure will be different from the schema. 
+ writer.list().bigInt().writeBigInt(0); + writer.list().bigInt().writeBigInt(1); + writer.list().endList(); + + // write data vector(0)(1) + writer.list().startList(); + writer.list().float8().writeFloat8(3.0D); + writer.list().float8().writeFloat8(7.0D); + writer.list().endList(); + + // finish data vector(0) + writer.endList(); + + writer.setPosition(1); + + // write data vector(1) + writer.startList(); + + // write data vector(1)(0) + writer.list().startList(); + writer.list().integer().writeInt(3); + writer.list().integer().writeInt(2); + writer.list().endList(); + + // finish data vector(1) + writer.endList(); + + vector.setValueCount(2); + + return schemaRoot; + } + + @Test + public void testSchemaSync() { + //create vector schema root + try (VectorSchemaRoot schemaRoot = createBatch()) { + Schema newSchema = new Schema( + schemaRoot.getFieldVectors().stream().map(vec -> vec.getField()).collect(Collectors.toList())); + + assertNotEquals(newSchema, schemaRoot.getSchema()); + assertTrue(schemaRoot.syncSchema()); + assertEquals(newSchema, schemaRoot.getSchema()); + + // no schema update this time. + assertFalse(schemaRoot.syncSchema()); + } + } } From 86f0a3a215379dcdcb28595ede712a16599a9990 Mon Sep 17 00:00:00 2001 From: Kenta Murata Date: Wed, 3 Jul 2019 11:26:52 +0200 Subject: [PATCH 39/52] ARROW-5813: [C++] Fix TensorEquals for different contiguous tensors This change makes TensorEquals correctly calculate the equality of a row-major tensor and a column-major tensor. 
Author: Kenta Murata Closes #4774 from mrkn/tensor_equals_for_different_contiguous and squashes the following commits: d40c3c774 Add unequal expectations e7ef17d82 Fix TensorEquals for different contiguous tensors --- cpp/src/arrow/compare.cc | 8 +++++- cpp/src/arrow/tensor-test.cc | 50 ++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 4ae5d897917..e1525a4f4d6 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -970,7 +970,13 @@ bool TensorEquals(const Tensor& left, const Tensor& right) { } else if (left.size() == 0) { are_equal = true; } else { - if (!left.is_contiguous() || !right.is_contiguous()) { + const bool left_row_major_p = left.is_row_major(); + const bool left_column_major_p = left.is_column_major(); + const bool right_row_major_p = right.is_row_major(); + const bool right_column_major_p = right.is_column_major(); + + if (!(left_row_major_p && right_row_major_p) && + !(left_column_major_p && right_column_major_p)) { const auto& shape = left.shape(); if (shape != right.shape()) { are_equal = false; diff --git a/cpp/src/arrow/tensor-test.cc b/cpp/src/arrow/tensor-test.cc index 36e97434d28..4638cd7739b 100644 --- a/cpp/src/arrow/tensor-test.cc +++ b/cpp/src/arrow/tensor-test.cc @@ -155,6 +155,56 @@ TEST(TestTensor, CountNonZeroForNonContiguousTensor) { AssertCountNonZero(t, 8); } +TEST(TestTensor, Equals) { + std::vector shape = {4, 4}; + + std::vector c_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector c_strides = {32, 8}; + Tensor tc1(int64(), Buffer::Wrap(c_values), shape, c_strides); + Tensor tc2(int64(), Buffer::Wrap(c_values), shape, c_strides); + + std::vector f_values = {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16}; + Tensor tc3(int64(), Buffer::Wrap(f_values), shape, c_strides); + + std::vector f_strides = {8, 32}; + Tensor tf1(int64(), Buffer::Wrap(f_values), shape, f_strides); + 
Tensor tf2(int64(), Buffer::Wrap(c_values), shape, f_strides); + + std::vector nc_values = {1, 0, 5, 0, 9, 0, 13, 0, 2, 0, 6, 0, 10, 0, 14, 0, + 3, 0, 7, 0, 11, 0, 15, 0, 4, 0, 8, 0, 12, 0, 16, 0}; + std::vector nc_strides = {16, 64}; + Tensor tnc(int64(), Buffer::Wrap(nc_values), shape, nc_strides); + + ASSERT_TRUE(tc1.is_contiguous()); + ASSERT_TRUE(tc1.is_row_major()); + + ASSERT_TRUE(tf1.is_contiguous()); + ASSERT_TRUE(tf1.is_column_major()); + + ASSERT_FALSE(tnc.is_contiguous()); + + // same object + EXPECT_TRUE(tc1.Equals(tc1)); + EXPECT_TRUE(tf1.Equals(tf1)); + EXPECT_TRUE(tnc.Equals(tnc)); + + // different objects + EXPECT_TRUE(tc1.Equals(tc2)); + EXPECT_FALSE(tc1.Equals(tc3)); + + // row-major and column-major + EXPECT_TRUE(tc1.Equals(tf1)); + EXPECT_FALSE(tc3.Equals(tf1)); + + // row-major and non-contiguous + EXPECT_TRUE(tc1.Equals(tnc)); + EXPECT_FALSE(tc3.Equals(tnc)); + + // column-major and non-contiguous + EXPECT_TRUE(tf1.Equals(tnc)); + EXPECT_FALSE(tf2.Equals(tnc)); +} + TEST(TestNumericTensor, ElementAccessWithRowMajorStrides) { std::vector shape = {3, 4}; From a9a82ec7d390f95c7e590fb463c0b5f6773d8e35 Mon Sep 17 00:00:00 2001 From: Micah Kornfield Date: Wed, 3 Jul 2019 12:14:41 +0200 Subject: [PATCH 40/52] ARROW-4036: [C++] Pluggable Status message, by exposing an abstract delegate class. This provides less "pluggability" but I think still offers a clean model for extension (subsystems can wrap the constructor for their purposes, and provide external static methods to check for particular types of errors).
Author: Micah Kornfield Author: Antoine Pitrou Closes #4484 from emkornfield/status_code_proposal and squashes the following commits: 4d1ab8d1d don't import plasma errors directly into top level pyarrow module a66f999f8 make format 040216d48 fixes for comments outside python 729bba1ff Fix Py2 issues (hopefully) ea56d1e6a Fix PythonErrorDetail to store Python error state (and restore it in check_status()) 21e1b95ac fix compilation 9c905b094 fix lint 74d563cd7 fixes 85786efb1 change messages 3626a9016 try removing message a4e6a1ff2 add logging for debug 4586fd1e2 fix typo 8f011b329 fix status propagation 317ea9c66 fix complie 9f5916070 don't make_shared inline 484b3a232 style fix 14e3467b5 dont rely on rtti cd22df64d format dec458506 not-quite pluggable error codes --- c_glib/arrow-glib/error.cpp | 13 +- c_glib/arrow-glib/error.h | 12 +- .../test/plasma/test-plasma-created-object.rb | 2 +- cpp/src/arrow/compute/kernels/cast.cc | 2 +- cpp/src/arrow/csv/column-builder.cc | 2 +- cpp/src/arrow/python/common.cc | 180 +++++++++++++----- cpp/src/arrow/python/common.h | 25 ++- cpp/src/arrow/python/python-test.cc | 106 ++++++++--- cpp/src/arrow/python/serialize.cc | 2 +- cpp/src/arrow/status-test.cc | 29 +++ cpp/src/arrow/status.cc | 32 ++-- cpp/src/arrow/status.h | 78 +++----- cpp/src/plasma/client.cc | 9 +- cpp/src/plasma/common.cc | 81 ++++++++ cpp/src/plasma/common.h | 17 ++ cpp/src/plasma/protocol.cc | 9 +- cpp/src/plasma/test/client_tests.cc | 4 +- cpp/src/plasma/test/serialization_tests.cc | 4 +- python/pyarrow/__init__.py | 3 +- python/pyarrow/_plasma.pyx | 94 ++++++--- python/pyarrow/error.pxi | 29 +-- python/pyarrow/includes/common.pxd | 5 +- python/pyarrow/includes/libarrow.pxd | 7 + python/pyarrow/includes/libplasma.pxd | 25 +++ python/pyarrow/plasma.py | 4 +- python/pyarrow/tests/test_array.py | 2 +- python/pyarrow/tests/test_convert_builtin.py | 73 +++++-- python/pyarrow/tests/test_plasma.py | 4 +- 28 files changed, 586 insertions(+), 267 deletions(-) create 
mode 100644 python/pyarrow/includes/libplasma.pxd diff --git a/c_glib/arrow-glib/error.cpp b/c_glib/arrow-glib/error.cpp index a56b6ec3d13..4c1461543f8 100644 --- a/c_glib/arrow-glib/error.cpp +++ b/c_glib/arrow-glib/error.cpp @@ -65,22 +65,15 @@ garrow_error_code(const arrow::Status &status) return GARROW_ERROR_NOT_IMPLEMENTED; case arrow::StatusCode::SerializationError: return GARROW_ERROR_SERIALIZATION; - case arrow::StatusCode::PythonError: - return GARROW_ERROR_PYTHON; - case arrow::StatusCode::PlasmaObjectExists: - return GARROW_ERROR_PLASMA_OBJECT_EXISTS; - case arrow::StatusCode::PlasmaObjectNonexistent: - return GARROW_ERROR_PLASMA_OBJECT_NONEXISTENT; - case arrow::StatusCode::PlasmaStoreFull: - return GARROW_ERROR_PLASMA_STORE_FULL; - case arrow::StatusCode::PlasmaObjectAlreadySealed: - return GARROW_ERROR_PLASMA_OBJECT_ALREADY_SEALED; case arrow::StatusCode::CodeGenError: return GARROW_ERROR_CODE_GENERATION; case arrow::StatusCode::ExpressionValidationError: return GARROW_ERROR_EXPRESSION_VALIDATION; case arrow::StatusCode::ExecutionError: return GARROW_ERROR_EXECUTION; + case arrow::StatusCode::AlreadyExists: + return GARROW_ERROR_ALREADY_EXISTS; + default: return GARROW_ERROR_UNKNOWN; } diff --git a/c_glib/arrow-glib/error.h b/c_glib/arrow-glib/error.h index 3dea9fc2e10..2fac5ad0d3e 100644 --- a/c_glib/arrow-glib/error.h +++ b/c_glib/arrow-glib/error.h @@ -35,15 +35,11 @@ G_BEGIN_DECLS * @GARROW_ERROR_UNKNOWN: Unknown error. * @GARROW_ERROR_NOT_IMPLEMENTED: The feature is not implemented. * @GARROW_ERROR_SERIALIZATION: Serialization error. - * @GARROW_ERROR_PYTHON: Python error. - * @GARROW_ERROR_PLASMA_OBJECT_EXISTS: Object already exists on Plasma. - * @GARROW_ERROR_PLASMA_OBJECT_NONEXISTENT: Object doesn't exist on Plasma. - * @GARROW_ERROR_PLASMA_STORE_FULL: Store full error on Plasma. - * @GARROW_ERROR_PLASMA_OBJECT_ALREADY_SEALED: Object already sealed on Plasma. 
* @GARROW_ERROR_CODE_GENERATION: Error generating code for expression evaluation * in Gandiva. * @GARROW_ERROR_EXPRESSION_VALIDATION: Validation errors in expression given for code generation. * @GARROW_ERROR_EXECUTION: Execution error while evaluating the expression against a record batch. + * @GARROW_ERROR_ALREADY_EXISTS: Item already exists error. * * The error codes are used by all arrow-glib functions. * @@ -60,14 +56,10 @@ typedef enum { GARROW_ERROR_UNKNOWN = 9, GARROW_ERROR_NOT_IMPLEMENTED, GARROW_ERROR_SERIALIZATION, - GARROW_ERROR_PYTHON, - GARROW_ERROR_PLASMA_OBJECT_EXISTS = 20, - GARROW_ERROR_PLASMA_OBJECT_NONEXISTENT, - GARROW_ERROR_PLASMA_STORE_FULL, - GARROW_ERROR_PLASMA_OBJECT_ALREADY_SEALED, GARROW_ERROR_CODE_GENERATION = 40, GARROW_ERROR_EXPRESSION_VALIDATION = 41, GARROW_ERROR_EXECUTION = 42, + GARROW_ERROR_ALREADY_EXISTS = 45, } GArrowError; #define GARROW_ERROR garrow_error_quark() diff --git a/c_glib/test/plasma/test-plasma-created-object.rb b/c_glib/test/plasma/test-plasma-created-object.rb index 9025ff4ac22..857322d20e1 100644 --- a/c_glib/test/plasma/test-plasma-created-object.rb +++ b/c_glib/test/plasma/test-plasma-created-object.rb @@ -45,7 +45,7 @@ def teardown test("#abort") do @object.data.set_data(0, @data) - assert_raise(Arrow::Error::PlasmaObjectExists) do + assert_raise(Arrow::Error::AlreadyExists) do @client.create(@id, @data.bytesize, @options) end @object.abort diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 299ca80402c..93feb656dd5 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -52,7 +52,7 @@ if (ARROW_PREDICT_FALSE(!_s.ok())) { \ std::stringstream ss; \ ss << __FILE__ << ":" << __LINE__ << " code: " << #s << "\n" << _s.message(); \ - ctx->SetStatus(Status(_s.code(), ss.str())); \ + ctx->SetStatus(Status(_s.code(), ss.str(), s.detail())); \ return; \ } \ } while (0) diff --git a/cpp/src/arrow/csv/column-builder.cc
b/cpp/src/arrow/csv/column-builder.cc index 657aa6f4e96..4099507016d 100644 --- a/cpp/src/arrow/csv/column-builder.cc +++ b/cpp/src/arrow/csv/column-builder.cc @@ -76,7 +76,7 @@ class TypedColumnBuilder : public ColumnBuilder { } else { std::stringstream ss; ss << "In column #" << col_index_ << ": " << st.message(); - return Status(st.code(), ss.str()); + return Status(st.code(), ss.str(), st.detail()); } } diff --git a/cpp/src/arrow/python/common.cc b/cpp/src/arrow/python/common.cc index aa44ec07e65..3cebc03cd22 100644 --- a/cpp/src/arrow/python/common.cc +++ b/cpp/src/arrow/python/common.cc @@ -23,11 +23,15 @@ #include "arrow/memory_pool.h" #include "arrow/status.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" #include "arrow/python/helpers.h" namespace arrow { + +using internal::checked_cast; + namespace py { static std::mutex memory_pool_mutex; @@ -47,6 +51,129 @@ MemoryPool* get_memory_pool() { } } +// ---------------------------------------------------------------------- +// PythonErrorDetail + +namespace { + +const char kErrorDetailTypeId[] = "arrow::py::PythonErrorDetail"; + +// Try to match the Python exception type with an appropriate Status code +StatusCode MapPyError(PyObject* exc_type) { + StatusCode code; + + if (PyErr_GivenExceptionMatches(exc_type, PyExc_MemoryError)) { + code = StatusCode::OutOfMemory; + } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_IndexError)) { + code = StatusCode::IndexError; + } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_KeyError)) { + code = StatusCode::KeyError; + } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_TypeError)) { + code = StatusCode::TypeError; + } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_ValueError) || + PyErr_GivenExceptionMatches(exc_type, PyExc_OverflowError)) { + code = StatusCode::Invalid; + } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_EnvironmentError)) { + code = StatusCode::IOError; + } else if (PyErr_GivenExceptionMatches(exc_type, 
PyExc_NotImplementedError)) { + code = StatusCode::NotImplemented; + } else { + code = StatusCode::UnknownError; + } + return code; +} + +// PythonErrorDetail indicates a Python exception was raised. +class PythonErrorDetail : public StatusDetail { + public: + const char* type_id() const override { return kErrorDetailTypeId; } + + std::string ToString() const override { + // This is simple enough not to need the GIL + const auto ty = reinterpret_cast(exc_type_.obj()); + // XXX Should we also print traceback? + return std::string("Python exception: ") + ty->tp_name; + } + + void RestorePyError() const { + Py_INCREF(exc_type_.obj()); + Py_INCREF(exc_value_.obj()); + Py_INCREF(exc_traceback_.obj()); + PyErr_Restore(exc_type_.obj(), exc_value_.obj(), exc_traceback_.obj()); + } + + PyObject* exc_type() const { return exc_type_.obj(); } + + PyObject* exc_value() const { return exc_value_.obj(); } + + static std::shared_ptr FromPyError() { + PyObject* exc_type = nullptr; + PyObject* exc_value = nullptr; + PyObject* exc_traceback = nullptr; + + PyErr_Fetch(&exc_type, &exc_value, &exc_traceback); + PyErr_NormalizeException(&exc_type, &exc_value, &exc_traceback); + ARROW_CHECK(exc_type) + << "PythonErrorDetail::FromPyError called without a Python error set"; + DCHECK(PyType_Check(exc_type)); + DCHECK(exc_value); // Ensured by PyErr_NormalizeException, double-check + if (exc_traceback == nullptr) { + // Needed by PyErr_Restore() + Py_INCREF(Py_None); + exc_traceback = Py_None; + } + + std::shared_ptr detail(new PythonErrorDetail); + detail->exc_type_.reset(exc_type); + detail->exc_value_.reset(exc_value); + detail->exc_traceback_.reset(exc_traceback); + return detail; + } + + protected: + PythonErrorDetail() = default; + + OwnedRefNoGIL exc_type_, exc_value_, exc_traceback_; +}; + +} // namespace + +// ---------------------------------------------------------------------- +// Python exception <-> Status + +Status ConvertPyError(StatusCode code) { + auto detail = 
PythonErrorDetail::FromPyError(); + if (code == StatusCode::UnknownError) { + code = MapPyError(detail->exc_type()); + } + + std::string message; + RETURN_NOT_OK(internal::PyObject_StdStringStr(detail->exc_value(), &message)); + return Status(code, message, detail); +} + +Status PassPyError() { + if (PyErr_Occurred()) { + return ConvertPyError(); + } + return Status::OK(); +} + +bool IsPyError(const Status& status) { + if (status.ok()) { + return false; + } + auto detail = status.detail(); + bool result = detail != nullptr && detail->type_id() == kErrorDetailTypeId; + return result; +} + +void RestorePyError(const Status& status) { + ARROW_CHECK(IsPyError(status)); + const auto& detail = checked_cast(*status.detail()); + detail.RestorePyError(); +} + // ---------------------------------------------------------------------- // PyBuffer @@ -64,7 +191,7 @@ Status PyBuffer::Init(PyObject* obj) { } return Status::OK(); } else { - return Status(StatusCode::PythonError, ""); + return ConvertPyError(StatusCode::Invalid); } } @@ -83,56 +210,5 @@ PyBuffer::~PyBuffer() { } } -// ---------------------------------------------------------------------- -// Python exception -> Status - -Status ConvertPyError(StatusCode code) { - PyObject* exc_type = nullptr; - PyObject* exc_value = nullptr; - PyObject* traceback = nullptr; - - PyErr_Fetch(&exc_type, &exc_value, &traceback); - PyErr_NormalizeException(&exc_type, &exc_value, &traceback); - - DCHECK_NE(exc_type, nullptr) << "ConvertPyError called without an exception set"; - - OwnedRef exc_type_ref(exc_type); - OwnedRef exc_value_ref(exc_value); - OwnedRef traceback_ref(traceback); - - std::string message; - RETURN_NOT_OK(internal::PyObject_StdStringStr(exc_value, &message)); - - if (code == StatusCode::UnknownError) { - // Try to match the Python exception type with an appropriate Status code - if (PyErr_GivenExceptionMatches(exc_type, PyExc_MemoryError)) { - code = StatusCode::OutOfMemory; - } else if 
(PyErr_GivenExceptionMatches(exc_type, PyExc_IndexError)) { - code = StatusCode::IndexError; - } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_KeyError)) { - code = StatusCode::KeyError; - } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_TypeError)) { - code = StatusCode::TypeError; - } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_ValueError) || - PyErr_GivenExceptionMatches(exc_type, PyExc_OverflowError)) { - code = StatusCode::Invalid; - } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_EnvironmentError)) { - code = StatusCode::IOError; - } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_NotImplementedError)) { - code = StatusCode::NotImplemented; - } - } - return Status(code, message); -} - -Status PassPyError() { - if (PyErr_Occurred()) { - // Do not call PyErr_Clear, the assumption is that someone further - // up the call stack will want to deal with the Python error. - return Status(StatusCode::PythonError, ""); - } - return Status::OK(); -} - } // namespace py } // namespace arrow diff --git a/cpp/src/arrow/python/common.h b/cpp/src/arrow/python/common.h index 766b76418de..9d3dc0c05ef 100644 --- a/cpp/src/arrow/python/common.h +++ b/cpp/src/arrow/python/common.h @@ -36,7 +36,15 @@ class Result; namespace py { +// Convert current Python error to a Status. The Python error state is cleared +// and can be restored with RestorePyError(). ARROW_PYTHON_EXPORT Status ConvertPyError(StatusCode code = StatusCode::UnknownError); +// Same as ConvertPyError(), but returns Status::OK() if no Python error is set. +ARROW_PYTHON_EXPORT Status PassPyError(); +// Query whether the given Status is a Python error (as wrapped by ConvertPyError()). +ARROW_PYTHON_EXPORT bool IsPyError(const Status& status); +// Restore a Python error wrapped in a Status. +ARROW_PYTHON_EXPORT void RestorePyError(const Status& status); // Catch a pending Python exception and return the corresponding Status. // If no exception is pending, Status::OK() is returned. 
@@ -48,9 +56,6 @@ inline Status CheckPyError(StatusCode code = StatusCode::UnknownError) { } } -ARROW_PYTHON_EXPORT Status PassPyError(); - -// TODO(wesm): We can just let errors pass through. To be explored later #define RETURN_IF_PYERROR() ARROW_RETURN_NOT_OK(CheckPyError()); #define PY_RETURN_IF_ERROR(CODE) ARROW_RETURN_NOT_OK(CheckPyError(CODE)); @@ -97,6 +102,18 @@ class ARROW_PYTHON_EXPORT PyAcquireGIL { ARROW_DISALLOW_COPY_AND_ASSIGN(PyAcquireGIL); }; +// A RAII-style helper that releases the GIL until the end of a lexical block +class ARROW_PYTHON_EXPORT PyReleaseGIL { + public: + PyReleaseGIL() { saved_state_ = PyEval_SaveThread(); } + + ~PyReleaseGIL() { PyEval_RestoreThread(saved_state_); } + + private: + PyThreadState* saved_state_; + ARROW_DISALLOW_COPY_AND_ASSIGN(PyReleaseGIL); +}; + // A helper to call safely into the Python interpreter from arbitrary C++ code. // The GIL is acquired, and the current thread's error status is preserved. template @@ -109,7 +126,7 @@ Status SafeCallIntoPython(Function&& func) { Status st = std::forward(func)(); // If the return Status is a "Python error", the current Python error status // describes the error and shouldn't be clobbered. 
- if (!st.IsPythonError() && exc_type != NULLPTR) { + if (!IsPyError(st) && exc_type != NULLPTR) { PyErr_Restore(exc_type, exc_value, exc_traceback); } return st; diff --git a/cpp/src/arrow/python/python-test.cc b/cpp/src/arrow/python/python-test.cc index 5de613f0e50..5027d3fe3f6 100644 --- a/cpp/src/arrow/python/python-test.cc +++ b/cpp/src/arrow/python/python-test.cc @@ -40,21 +40,12 @@ using internal::checked_cast; namespace py { -TEST(PyBuffer, InvalidInputObject) { - std::shared_ptr res; - PyObject* input = Py_None; - auto old_refcnt = Py_REFCNT(input); - ASSERT_RAISES(PythonError, PyBuffer::FromPyObject(input, &res)); - PyErr_Clear(); - ASSERT_EQ(old_refcnt, Py_REFCNT(input)); -} - TEST(OwnedRef, TestMoves) { - PyAcquireGIL lock; std::vector vec; PyObject *u, *v; u = PyList_New(0); v = PyList_New(0); + { OwnedRef ref(u); vec.push_back(std::move(ref)); @@ -66,31 +57,42 @@ TEST(OwnedRef, TestMoves) { } TEST(OwnedRefNoGIL, TestMoves) { - std::vector vec; - PyObject *u, *v; - { - PyAcquireGIL lock; - u = PyList_New(0); - v = PyList_New(0); - } + PyAcquireGIL lock; + lock.release(); + { - OwnedRefNoGIL ref(u); - vec.push_back(std::move(ref)); - ASSERT_EQ(ref.obj(), nullptr); + std::vector vec; + PyObject *u, *v; + { + lock.acquire(); + u = PyList_New(0); + v = PyList_New(0); + lock.release(); + } + { + OwnedRefNoGIL ref(u); + vec.push_back(std::move(ref)); + ASSERT_EQ(ref.obj(), nullptr); + } + vec.emplace_back(v); + ASSERT_EQ(Py_REFCNT(u), 1); + ASSERT_EQ(Py_REFCNT(v), 1); } - vec.emplace_back(v); - ASSERT_EQ(Py_REFCNT(u), 1); - ASSERT_EQ(Py_REFCNT(v), 1); } TEST(CheckPyError, TestStatus) { - PyAcquireGIL lock; Status st; - auto check_error = [](Status& st, const char* expected_message = "some error") { + auto check_error = [](Status& st, const char* expected_message = "some error", + const char* expected_detail = nullptr) { st = CheckPyError(); ASSERT_EQ(st.message(), expected_message); ASSERT_FALSE(PyErr_Occurred()); + if (expected_detail) { + auto detail = 
st.detail(); + ASSERT_NE(detail, nullptr); + ASSERT_EQ(detail->ToString(), expected_detail); + } }; for (PyObject* exc_type : {PyExc_Exception, PyExc_SyntaxError}) { @@ -100,7 +102,7 @@ TEST(CheckPyError, TestStatus) { } PyErr_SetString(PyExc_TypeError, "some error"); - check_error(st); + check_error(st, "some error", "Python exception: TypeError"); ASSERT_TRUE(st.IsTypeError()); PyErr_SetString(PyExc_ValueError, "some error"); @@ -118,7 +120,7 @@ TEST(CheckPyError, TestStatus) { } PyErr_SetString(PyExc_NotImplementedError, "some error"); - check_error(st); + check_error(st, "some error", "Python exception: NotImplementedError"); ASSERT_TRUE(st.IsNotImplemented()); // No override if a specific status code is given @@ -129,6 +131,52 @@ TEST(CheckPyError, TestStatus) { ASSERT_FALSE(PyErr_Occurred()); } +TEST(CheckPyError, TestStatusNoGIL) { + PyAcquireGIL lock; + { + Status st; + PyErr_SetString(PyExc_ZeroDivisionError, "zzzt"); + st = ConvertPyError(); + ASSERT_FALSE(PyErr_Occurred()); + lock.release(); + ASSERT_TRUE(st.IsUnknownError()); + ASSERT_EQ(st.message(), "zzzt"); + ASSERT_EQ(st.detail()->ToString(), "Python exception: ZeroDivisionError"); + } +} + +TEST(RestorePyError, Basics) { + PyErr_SetString(PyExc_ZeroDivisionError, "zzzt"); + auto st = ConvertPyError(); + ASSERT_FALSE(PyErr_Occurred()); + ASSERT_TRUE(st.IsUnknownError()); + ASSERT_EQ(st.message(), "zzzt"); + ASSERT_EQ(st.detail()->ToString(), "Python exception: ZeroDivisionError"); + + RestorePyError(st); + ASSERT_TRUE(PyErr_Occurred()); + PyObject* exc_type; + PyObject* exc_value; + PyObject* exc_traceback; + PyErr_Fetch(&exc_type, &exc_value, &exc_traceback); + ASSERT_TRUE(PyErr_GivenExceptionMatches(exc_type, PyExc_ZeroDivisionError)); + std::string py_message; + ASSERT_OK(internal::PyObject_StdStringStr(exc_value, &py_message)); + ASSERT_EQ(py_message, "zzzt"); +} + +TEST(PyBuffer, InvalidInputObject) { + std::shared_ptr res; + PyObject* input = Py_None; + auto old_refcnt = Py_REFCNT(input); + { 
+ Status st = PyBuffer::FromPyObject(input, &res); + ASSERT_TRUE(IsPyError(st)) << st.ToString(); + ASSERT_FALSE(PyErr_Occurred()); + } + ASSERT_EQ(old_refcnt, Py_REFCNT(input)); +} + class DecimalTest : public ::testing::Test { public: DecimalTest() : lock_(), decimal_constructor_() { @@ -253,8 +301,6 @@ TEST(PandasConversionTest, TestObjectBlockWriteFails) { } TEST(BuiltinConversionTest, TestMixedTypeFails) { - PyAcquireGIL lock; - OwnedRef list_ref(PyList_New(3)); PyObject* list = list_ref.obj(); @@ -405,8 +451,6 @@ TEST_F(DecimalTest, TestMixedPrecisionAndScale) { } TEST_F(DecimalTest, TestMixedPrecisionAndScaleSequenceConvert) { - PyAcquireGIL lock; - PyObject* value1 = this->CreatePythonDecimal("0.01").detach(); ASSERT_NE(value1, nullptr); diff --git a/cpp/src/arrow/python/serialize.cc b/cpp/src/arrow/python/serialize.cc index d93e3954e41..57843943775 100644 --- a/cpp/src/arrow/python/serialize.cc +++ b/cpp/src/arrow/python/serialize.cc @@ -332,8 +332,8 @@ Status SequenceBuilder::AppendDict(PyObject* context, PyObject* dict, Status CallCustomCallback(PyObject* context, PyObject* method_name, PyObject* elem, PyObject** result) { - *result = NULL; if (context == Py_None) { + *result = NULL; return Status::SerializationError("error while calling callback on ", internal::PyObject_StdStringRepr(elem), ": handler not registered"); diff --git a/cpp/src/arrow/status-test.cc b/cpp/src/arrow/status-test.cc index b7fc61f4801..b151e462b28 100644 --- a/cpp/src/arrow/status-test.cc +++ b/cpp/src/arrow/status-test.cc @@ -23,6 +23,16 @@ namespace arrow { +namespace { + +class TestStatusDetail : public StatusDetail { + public: + const char* type_id() const override { return "type_id"; } + std::string ToString() const override { return "a specific detail message"; } +}; + +} // namespace + TEST(StatusTest, TestCodeAndMessage) { Status ok = Status::OK(); ASSERT_EQ(StatusCode::OK, ok.code()); @@ -40,6 +50,25 @@ TEST(StatusTest, TestToString) { ASSERT_EQ(file_error.ToString(), 
ss.str()); } +TEST(StatusTest, TestToStringWithDetail) { + Status status(StatusCode::IOError, "summary", std::make_shared()); + ASSERT_EQ("IOError: summary. Detail: a specific detail message", status.ToString()); + + std::stringstream ss; + ss << status; + ASSERT_EQ(status.ToString(), ss.str()); +} + +TEST(StatusTest, TestWithDetail) { + Status status(StatusCode::IOError, "summary"); + auto detail = std::make_shared(); + Status new_status = status.WithDetail(detail); + + ASSERT_EQ(new_status.code(), status.code()); + ASSERT_EQ(new_status.message(), status.message()); + ASSERT_EQ(new_status.detail(), detail); +} + TEST(StatusTest, AndStatus) { Status a = Status::OK(); Status b = Status::OK(); diff --git a/cpp/src/arrow/status.cc b/cpp/src/arrow/status.cc index cbb29119be6..785db459752 100644 --- a/cpp/src/arrow/status.cc +++ b/cpp/src/arrow/status.cc @@ -21,11 +21,17 @@ namespace arrow { -Status::Status(StatusCode code, const std::string& msg) { +Status::Status(StatusCode code, const std::string& msg) + : Status::Status(code, msg, nullptr) {} + +Status::Status(StatusCode code, std::string msg, std::shared_ptr detail) { ARROW_CHECK_NE(code, StatusCode::OK) << "Cannot construct ok status with message"; state_ = new State; state_->code = code; - state_->msg = msg; + state_->msg = std::move(msg); + if (detail != nullptr) { + state_->detail = std::move(detail); + } } void Status::CopyFrom(const Status& s) { @@ -77,21 +83,6 @@ std::string Status::CodeAsString() const { case StatusCode::SerializationError: type = "Serialization error"; break; - case StatusCode::PythonError: - type = "Python error"; - break; - case StatusCode::PlasmaObjectExists: - type = "Plasma object exists"; - break; - case StatusCode::PlasmaObjectNonexistent: - type = "Plasma object is nonexistent"; - break; - case StatusCode::PlasmaStoreFull: - type = "Plasma store is full"; - break; - case StatusCode::PlasmaObjectAlreadySealed: - type = "Plasma object is already sealed"; - break; case 
StatusCode::CodeGenError: type = "CodeGenError in Gandiva"; break; @@ -110,11 +101,16 @@ std::string Status::CodeAsString() const { std::string Status::ToString() const { std::string result(CodeAsString()); - if (state_ == NULL) { + if (state_ == nullptr) { return result; } result += ": "; result += state_->msg; + if (state_->detail != nullptr) { + result += ". Detail: "; + result += state_->detail->ToString(); + } + return result; } diff --git a/cpp/src/arrow/status.h b/cpp/src/arrow/status.h index 1ed0da65fc4..7cafc41902d 100644 --- a/cpp/src/arrow/status.h +++ b/cpp/src/arrow/status.h @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -85,17 +86,13 @@ enum class StatusCode : char { UnknownError = 9, NotImplemented = 10, SerializationError = 11, - PythonError = 12, RError = 13, - PlasmaObjectExists = 20, - PlasmaObjectNonexistent = 21, - PlasmaStoreFull = 22, - PlasmaObjectAlreadySealed = 23, - StillExecuting = 24, // Gandiva range of errors CodeGenError = 40, ExpressionValidationError = 41, - ExecutionError = 42 + ExecutionError = 42, + // Continue generic codes. + AlreadyExists = 45 }; #if defined(__clang__) @@ -103,6 +100,17 @@ enum class StatusCode : char { class ARROW_MUST_USE_RESULT ARROW_EXPORT Status; #endif +/// \brief An opaque class that allows subsystems to retain +/// additional information inside the Status. +class ARROW_EXPORT StatusDetail { + public: + virtual ~StatusDetail() = default; + // Return a unique id for the type of the StatusDetail + // (effectively a poor man's substitute for RTTI). + virtual const char* type_id() const = 0; + virtual std::string ToString() const = 0; +}; + /// \brief Status outcome object (success or error) /// /// The Status object is an object holding the outcome of an operation. @@ -124,6 +132,8 @@ class ARROW_EXPORT Status { } Status(StatusCode code, const std::string& msg); + /// \brief Pluggable constructor for use by sub-systems. detail may be null.
+ Status(StatusCode code, std::string msg, std::shared_ptr detail); // Copy the specified status. inline Status(const Status& s); @@ -221,32 +231,6 @@ class ARROW_EXPORT Status { return Status(StatusCode::RError, util::StringBuilder(std::forward(args)...)); } - template - static Status PlasmaObjectExists(Args&&... args) { - return Status(StatusCode::PlasmaObjectExists, - util::StringBuilder(std::forward(args)...)); - } - - template - static Status PlasmaObjectNonexistent(Args&&... args) { - return Status(StatusCode::PlasmaObjectNonexistent, - util::StringBuilder(std::forward(args)...)); - } - - template - static Status PlasmaObjectAlreadySealed(Args&&... args) { - return Status(StatusCode::PlasmaObjectAlreadySealed, - util::StringBuilder(std::forward(args)...)); - } - - template - static Status PlasmaStoreFull(Args&&... args) { - return Status(StatusCode::PlasmaStoreFull, - util::StringBuilder(std::forward(args)...)); - } - - static Status StillExecuting() { return Status(StatusCode::StillExecuting, ""); } - template static Status CodeGenError(Args&&... args) { return Status(StatusCode::CodeGenError, @@ -290,22 +274,6 @@ class ARROW_EXPORT Status { bool IsSerializationError() const { return code() == StatusCode::SerializationError; } /// Return true iff the status indicates a R-originated error. bool IsRError() const { return code() == StatusCode::RError; } - /// Return true iff the status indicates a Python-originated error. - bool IsPythonError() const { return code() == StatusCode::PythonError; } - /// Return true iff the status indicates an already existing Plasma object. - bool IsPlasmaObjectExists() const { return code() == StatusCode::PlasmaObjectExists; } - /// Return true iff the status indicates a non-existent Plasma object. - bool IsPlasmaObjectNonexistent() const { - return code() == StatusCode::PlasmaObjectNonexistent; - } - /// Return true iff the status indicates an already sealed Plasma object. 
- bool IsPlasmaObjectAlreadySealed() const { - return code() == StatusCode::PlasmaObjectAlreadySealed; - } - /// Return true iff the status indicates the Plasma store reached its capacity limit. - bool IsPlasmaStoreFull() const { return code() == StatusCode::PlasmaStoreFull; } - - bool IsStillExecuting() const { return code() == StatusCode::StillExecuting; } bool IsCodeGenError() const { return code() == StatusCode::CodeGenError; } @@ -330,6 +298,17 @@ class ARROW_EXPORT Status { /// \brief Return the specific error message attached to this status. std::string message() const { return ok() ? "" : state_->msg; } + /// \brief Return the status detail attached to this message. + std::shared_ptr detail() const { + return state_ == NULLPTR ? NULLPTR : state_->detail; + } + + /// \brief Returns a new Status copying the existing status, but + /// updating with the existing detail. + Status WithDetail(std::shared_ptr new_detail) { + return Status(code(), message(), std::move(new_detail)); + } + [[noreturn]] void Abort() const; [[noreturn]] void Abort(const std::string& message) const; @@ -341,6 +320,7 @@ class ARROW_EXPORT Status { struct State { StatusCode code; std::string msg; + std::shared_ptr detail; }; // OK status has a `NULL` state_. 
Otherwise, `state_` points to // a `State` structure containing the error code and message(s) diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index ce9795d20fc..a6cdf7f17ca 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -791,11 +791,12 @@ Status PlasmaClient::Impl::Seal(const ObjectID& object_id) { auto object_entry = objects_in_use_.find(object_id); if (object_entry == objects_in_use_.end()) { - return Status::PlasmaObjectNonexistent( - "Seal() called on an object without a reference to it"); + return MakePlasmaError(PlasmaErrorCode::PlasmaObjectNonexistent, + "Seal() called on an object without a reference to it"); } if (object_entry->second->is_sealed) { - return Status::PlasmaObjectAlreadySealed("Seal() called on an already sealed object"); + return MakePlasmaError(PlasmaErrorCode::PlasmaObjectAlreadySealed, + "Seal() called on an already sealed object"); } object_entry->second->is_sealed = true; @@ -896,7 +897,7 @@ Status PlasmaClient::Impl::Hash(const ObjectID& object_id, uint8_t* digest) { RETURN_NOT_OK(Get({object_id}, 0, &object_buffers)); // If the object was not retrieved, return false. if (!object_buffers[0].data) { - return Status::PlasmaObjectNonexistent("Object not found"); + return MakePlasmaError(PlasmaErrorCode::PlasmaObjectNonexistent, "Object not found"); } // Compute the hash. 
uint64_t hash = ComputeObjectHash(object_buffers[0]); diff --git a/cpp/src/plasma/common.cc b/cpp/src/plasma/common.cc index 0f1a0d1b505..bbcd2c9c3f1 100644 --- a/cpp/src/plasma/common.cc +++ b/cpp/src/plasma/common.cc @@ -18,6 +18,7 @@ #include "plasma/common.h" #include +#include #include "arrow/util/ubsan.h" @@ -27,8 +28,88 @@ namespace fb = plasma::flatbuf; namespace plasma { +namespace { + +const char kErrorDetailTypeId[] = "plasma::PlasmaStatusDetail"; + +class PlasmaStatusDetail : public arrow::StatusDetail { + public: + explicit PlasmaStatusDetail(PlasmaErrorCode code) : code_(code) {} + const char* type_id() const override { return kErrorDetailTypeId; } + std::string ToString() const override { + const char* type; + switch (code()) { + case PlasmaErrorCode::PlasmaObjectExists: + type = "Plasma object exists"; + break; + case PlasmaErrorCode::PlasmaObjectNonexistent: + type = "Plasma object is nonexistent"; + break; + case PlasmaErrorCode::PlasmaStoreFull: + type = "Plasma store is full"; + break; + case PlasmaErrorCode::PlasmaObjectAlreadySealed: + type = "Plasma object is already sealed"; + break; + default: + type = "Unknown plasma error"; + break; + } + return std::string(type); + } + PlasmaErrorCode code() const { return code_; } + + private: + PlasmaErrorCode code_; +}; + +bool IsPlasmaStatus(const arrow::Status& status, PlasmaErrorCode code) { + if (status.ok()) { + return false; + } + auto* detail = status.detail().get(); + return detail != nullptr && detail->type_id() == kErrorDetailTypeId && + static_cast(detail)->code() == code; +} + +} // namespace + using arrow::Status; +arrow::Status MakePlasmaError(PlasmaErrorCode code, std::string message) { + arrow::StatusCode arrow_code = arrow::StatusCode::UnknownError; + switch (code) { + case PlasmaErrorCode::PlasmaObjectExists: + arrow_code = arrow::StatusCode::AlreadyExists; + break; + case PlasmaErrorCode::PlasmaObjectNonexistent: + arrow_code = arrow::StatusCode::KeyError; + break; + case 
PlasmaErrorCode::PlasmaStoreFull: + arrow_code = arrow::StatusCode::CapacityError; + break; + case PlasmaErrorCode::PlasmaObjectAlreadySealed: + // Maybe a stretch? + arrow_code = arrow::StatusCode::TypeError; + break; + } + return arrow::Status(arrow_code, std::move(message), + std::make_shared(code)); +} + +bool IsPlasmaObjectExists(const arrow::Status& status) { + return IsPlasmaStatus(status, PlasmaErrorCode::PlasmaObjectExists); +} +bool IsPlasmaObjectNonexistent(const arrow::Status& status) { + return IsPlasmaStatus(status, PlasmaErrorCode::PlasmaObjectNonexistent); +} +bool IsPlasmaObjectAlreadySealed(const arrow::Status& status) { + return IsPlasmaStatus(status, PlasmaErrorCode::PlasmaObjectAlreadySealed); +} +bool IsPlasmaStoreFull(const arrow::Status& status) { + return IsPlasmaStatus(status, PlasmaErrorCode::PlasmaStoreFull); +} + UniqueID UniqueID::from_binary(const std::string& binary) { UniqueID id; std::memcpy(&id, binary.data(), sizeof(id)); diff --git a/cpp/src/plasma/common.h b/cpp/src/plasma/common.h index 6f4cef5becb..d42840cfbd2 100644 --- a/cpp/src/plasma/common.h +++ b/cpp/src/plasma/common.h @@ -41,6 +41,23 @@ namespace plasma { enum class ObjectLocation : int32_t { Local, Remote, Nonexistent }; +enum class PlasmaErrorCode : int8_t { + PlasmaObjectExists = 1, + PlasmaObjectNonexistent = 2, + PlasmaStoreFull = 3, + PlasmaObjectAlreadySealed = 4, +}; + +ARROW_EXPORT arrow::Status MakePlasmaError(PlasmaErrorCode code, std::string message); +/// Return true iff the status indicates an already existing Plasma object. +ARROW_EXPORT bool IsPlasmaObjectExists(const arrow::Status& status); +/// Return true iff the status indicates a non-existent Plasma object. +ARROW_EXPORT bool IsPlasmaObjectNonexistent(const arrow::Status& status); +/// Return true iff the status indicates an already sealed Plasma object. 
+ARROW_EXPORT bool IsPlasmaObjectAlreadySealed(const arrow::Status& status); +/// Return true iff the status indicates the Plasma store reached its capacity limit. +ARROW_EXPORT bool IsPlasmaStoreFull(const arrow::Status& status); + constexpr int64_t kUniqueIDSize = 20; class ARROW_EXPORT UniqueID { diff --git a/cpp/src/plasma/protocol.cc b/cpp/src/plasma/protocol.cc index b87656bd097..c22d77d6019 100644 --- a/cpp/src/plasma/protocol.cc +++ b/cpp/src/plasma/protocol.cc @@ -86,11 +86,14 @@ Status PlasmaErrorStatus(fb::PlasmaError plasma_error) { case fb::PlasmaError::OK: return Status::OK(); case fb::PlasmaError::ObjectExists: - return Status::PlasmaObjectExists("object already exists in the plasma store"); + return MakePlasmaError(PlasmaErrorCode::PlasmaObjectExists, + "object already exists in the plasma store"); case fb::PlasmaError::ObjectNonexistent: - return Status::PlasmaObjectNonexistent("object does not exist in the plasma store"); + return MakePlasmaError(PlasmaErrorCode::PlasmaObjectNonexistent, + "object does not exist in the plasma store"); case fb::PlasmaError::OutOfMemory: - return Status::PlasmaStoreFull("object does not fit in the plasma store"); + return MakePlasmaError(PlasmaErrorCode::PlasmaStoreFull, + "object does not fit in the plasma store"); default: ARROW_LOG(FATAL) << "unknown plasma error code " << static_cast(plasma_error); } diff --git a/cpp/src/plasma/test/client_tests.cc b/cpp/src/plasma/test/client_tests.cc index 435b687a69e..deffde57976 100644 --- a/cpp/src/plasma/test/client_tests.cc +++ b/cpp/src/plasma/test/client_tests.cc @@ -157,7 +157,7 @@ TEST_F(TestPlasmaStore, SealErrorsTest) { ObjectID object_id = random_object_id(); Status result = client_.Seal(object_id); - ASSERT_TRUE(result.IsPlasmaObjectNonexistent()); + ASSERT_TRUE(IsPlasmaObjectNonexistent(result)); // Create object. std::vector data(100, 0); @@ -165,7 +165,7 @@ TEST_F(TestPlasmaStore, SealErrorsTest) { // Trying to seal it again. 
result = client_.Seal(object_id); - ASSERT_TRUE(result.IsPlasmaObjectAlreadySealed()); + ASSERT_TRUE(IsPlasmaObjectAlreadySealed(result)); ARROW_CHECK_OK(client_.Release(object_id)); } diff --git a/cpp/src/plasma/test/serialization_tests.cc b/cpp/src/plasma/test/serialization_tests.cc index 7e2bc887ed3..f3cff428582 100644 --- a/cpp/src/plasma/test/serialization_tests.cc +++ b/cpp/src/plasma/test/serialization_tests.cc @@ -156,7 +156,7 @@ TEST_F(TestPlasmaSerialization, SealReply) { ObjectID object_id2; Status s = ReadSealReply(data.data(), data.size(), &object_id2); ASSERT_EQ(object_id1, object_id2); - ASSERT_TRUE(s.IsPlasmaObjectExists()); + ASSERT_TRUE(IsPlasmaObjectExists(s)); close(fd); } @@ -234,7 +234,7 @@ TEST_F(TestPlasmaSerialization, ReleaseReply) { ObjectID object_id2; Status s = ReadReleaseReply(data.data(), data.size(), &object_id2); ASSERT_EQ(object_id1, object_id2); - ASSERT_TRUE(s.IsPlasmaObjectExists()); + ASSERT_TRUE(IsPlasmaObjectExists(s)); close(fd); } diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index bbbd91a9508..1d508ed7d11 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -122,8 +122,7 @@ def parse_git(root, **kwargs): ArrowMemoryError, ArrowNotImplementedError, ArrowTypeError, - ArrowSerializationError, - PlasmaObjectExists) + ArrowSerializationError) # Serialization from pyarrow.lib import (deserialize_from, deserialize, diff --git a/python/pyarrow/_plasma.pyx b/python/pyarrow/_plasma.pyx index e352377f14e..7e994c3ee07 100644 --- a/python/pyarrow/_plasma.pyx +++ b/python/pyarrow/_plasma.pyx @@ -37,8 +37,10 @@ import warnings import pyarrow from pyarrow.lib cimport Buffer, NativeFile, check_status, pyarrow_wrap_buffer +from pyarrow.lib import ArrowException from pyarrow.includes.libarrow cimport (CBuffer, CMutableBuffer, CFixedSizeBufferWriter, CStatus) +from pyarrow.includes.libplasma cimport * from pyarrow import compat @@ -255,6 +257,34 @@ cdef class PlasmaBuffer(Buffer): 
self.client._release(self.object_id) +class PlasmaObjectNonexistent(ArrowException): + pass + + +class PlasmaStoreFull(ArrowException): + pass + + +class PlasmaObjectExists(ArrowException): + pass + + +cdef int plasma_check_status(const CStatus& status) nogil except -1: + if status.ok(): + return 0 + + with gil: + message = compat.frombytes(status.message()) + if IsPlasmaObjectExists(status): + raise PlasmaObjectExists(message) + elif IsPlasmaObjectNonexistent(status): + raise PlasmaObjectNonexistent(message) + elif IsPlasmaStoreFull(status): + raise PlasmaStoreFull(message) + + return check_status(status) + + cdef class PlasmaClient: """ The PlasmaClient is used to interface with a plasma store and manager. @@ -283,7 +313,7 @@ cdef class PlasmaClient: for object_id in object_ids: ids.push_back(object_id.data) with nogil: - check_status(self.client.get().Get(ids, timeout_ms, result)) + plasma_check_status(self.client.get().Get(ids, timeout_ms, result)) # XXX C++ API should instead expose some kind of CreateAuto() cdef _make_mutable_plasma_buffer(self, ObjectID object_id, uint8_t* data, @@ -325,9 +355,10 @@ cdef class PlasmaClient: """ cdef shared_ptr[CBuffer] data with nogil: - check_status(self.client.get().Create(object_id.data, data_size, - (metadata.data()), - metadata.size(), &data)) + plasma_check_status( + self.client.get().Create(object_id.data, data_size, + (metadata.data()), + metadata.size(), &data)) return self._make_mutable_plasma_buffer(object_id, data.get().mutable_data(), data_size) @@ -358,8 +389,9 @@ cdef class PlasmaClient: enough objects to create room for it. """ with nogil: - check_status(self.client.get().CreateAndSeal(object_id.data, data, - metadata)) + plasma_check_status( + self.client.get().CreateAndSeal(object_id.data, data, + metadata)) def get_buffers(self, object_ids, timeout_ms=-1, with_meta=False): """ @@ -554,7 +586,7 @@ cdef class PlasmaClient: A string used to identify an object. 
""" with nogil: - check_status(self.client.get().Seal(object_id.data)) + plasma_check_status(self.client.get().Seal(object_id.data)) def _release(self, ObjectID object_id): """ @@ -566,7 +598,7 @@ cdef class PlasmaClient: A string used to identify an object. """ with nogil: - check_status(self.client.get().Release(object_id.data)) + plasma_check_status(self.client.get().Release(object_id.data)) def contains(self, ObjectID object_id): """ @@ -579,8 +611,8 @@ cdef class PlasmaClient: """ cdef c_bool is_contained with nogil: - check_status(self.client.get().Contains(object_id.data, - &is_contained)) + plasma_check_status(self.client.get().Contains(object_id.data, + &is_contained)) return is_contained def hash(self, ObjectID object_id): @@ -600,8 +632,8 @@ cdef class PlasmaClient: """ cdef c_vector[uint8_t] digest = c_vector[uint8_t](kDigestSize) with nogil: - check_status(self.client.get().Hash(object_id.data, - digest.data())) + plasma_check_status(self.client.get().Hash(object_id.data, + digest.data())) return bytes(digest[:]) def evict(self, int64_t num_bytes): @@ -617,13 +649,15 @@ cdef class PlasmaClient: """ cdef int64_t num_bytes_evicted = -1 with nogil: - check_status(self.client.get().Evict(num_bytes, num_bytes_evicted)) + plasma_check_status( + self.client.get().Evict(num_bytes, num_bytes_evicted)) return num_bytes_evicted def subscribe(self): """Subscribe to notifications about sealed objects.""" with nogil: - check_status(self.client.get().Subscribe(&self.notification_fd)) + plasma_check_status( + self.client.get().Subscribe(&self.notification_fd)) def get_notification_socket(self): """ @@ -650,11 +684,11 @@ cdef class PlasmaClient: cdef int64_t data_size cdef int64_t metadata_size with nogil: - check_status(self.client.get() - .DecodeNotification(buf, - &object_id, - &data_size, - &metadata_size)) + status = self.client.get().DecodeNotification(buf, + &object_id, + &data_size, + &metadata_size) + plasma_check_status(status) return 
ObjectID(object_id.binary()), data_size, metadata_size def get_next_notification(self): @@ -674,11 +708,11 @@ cdef class PlasmaClient: cdef int64_t data_size cdef int64_t metadata_size with nogil: - check_status(self.client.get() - .GetNotification(self.notification_fd, - &object_id.data, - &data_size, - &metadata_size)) + status = self.client.get().GetNotification(self.notification_fd, + &object_id.data, + &data_size, + &metadata_size) + plasma_check_status(status) return object_id, data_size, metadata_size def to_capsule(self): @@ -689,7 +723,7 @@ cdef class PlasmaClient: Disconnect this client from the Plasma store. """ with nogil: - check_status(self.client.get().Disconnect()) + plasma_check_status(self.client.get().Disconnect()) def delete(self, object_ids): """ @@ -705,7 +739,7 @@ cdef class PlasmaClient: for object_id in object_ids: ids.push_back(object_id.data) with nogil: - check_status(self.client.get().Delete(ids)) + plasma_check_status(self.client.get().Delete(ids)) def list(self): """ @@ -738,7 +772,7 @@ cdef class PlasmaClient: """ cdef CObjectTable objects with nogil: - check_status(self.client.get().List(&objects)) + plasma_check_status(self.client.get().List(&objects)) result = dict() cdef ObjectID object_id cdef CObjectTableEntry entry @@ -802,7 +836,7 @@ def connect(store_socket_name, manager_socket_name=None, int release_delay=0, warnings.warn("release_delay in PlasmaClient.connect is deprecated", FutureWarning) with nogil: - check_status(result.client.get() - .Connect(result.store_socket_name, b"", - release_delay, num_retries)) + plasma_check_status( + result.client.get().Connect(result.store_socket_name, b"", + release_delay, num_retries)) return result diff --git a/python/pyarrow/error.pxi b/python/pyarrow/error.pxi index 7b5e8d43371..3cb9142d479 100644 --- a/python/pyarrow/error.pxi +++ b/python/pyarrow/error.pxi @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from pyarrow.includes.libarrow cimport CStatus +from pyarrow.includes.libarrow cimport CStatus, IsPyError, RestorePyError from pyarrow.includes.common cimport c_string from pyarrow.compat import frombytes @@ -56,30 +56,21 @@ class ArrowIndexError(IndexError, ArrowException): pass -class PlasmaObjectExists(ArrowException): - pass - - -class PlasmaObjectNonexistent(ArrowException): - pass - - -class PlasmaStoreFull(ArrowException): - pass - - class ArrowSerializationError(ArrowException): pass +# This function could be written directly in C++ if we didn't +# define Arrow-specific subclasses (ArrowInvalid etc.) cdef int check_status(const CStatus& status) nogil except -1: if status.ok(): return 0 - if status.IsPythonError(): - return -1 - with gil: + if IsPyError(status): + RestorePyError(status) + return -1 + message = frombytes(status.message()) if status.IsInvalid(): raise ArrowInvalid(message) @@ -97,12 +88,6 @@ cdef int check_status(const CStatus& status) nogil except -1: raise ArrowCapacityError(message) elif status.IsIndexError(): raise ArrowIndexError(message) - elif status.IsPlasmaObjectExists(): - raise PlasmaObjectExists(message) - elif status.IsPlasmaObjectNonexistent(): - raise PlasmaObjectNonexistent(message) - elif status.IsPlasmaStoreFull(): - raise PlasmaStoreFull(message) elif status.IsSerializationError(): raise ArrowSerializationError(message) else: diff --git a/python/pyarrow/includes/common.pxd b/python/pyarrow/includes/common.pxd index 4a06fc82065..8b116f60b6e 100644 --- a/python/pyarrow/includes/common.pxd +++ b/python/pyarrow/includes/common.pxd @@ -42,6 +42,7 @@ cdef extern from "numpy/halffloat.h": cdef extern from "arrow/api.h" namespace "arrow" nogil: # We can later add more of the common status factory methods as needed cdef CStatus CStatus_OK "arrow::Status::OK"() + cdef CStatus CStatus_Invalid "arrow::Status::Invalid"() cdef CStatus CStatus_NotImplemented \ "arrow::Status::NotImplemented"(const c_string& msg) @@ -64,10 +65,6 @@ cdef 
extern from "arrow/api.h" namespace "arrow" nogil: c_bool IsCapacityError() c_bool IsIndexError() c_bool IsSerializationError() - c_bool IsPythonError() - c_bool IsPlasmaObjectExists() - c_bool IsPlasmaObjectNonexistent() - c_bool IsPlasmaStoreFull() cdef extern from "arrow/result.h" namespace "arrow::internal" nogil: cdef cppclass CResult[T]: diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 93a75945ce3..89199ca77fb 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1217,6 +1217,8 @@ cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil: CMemoryPool* pool c_bool from_pandas + # TODO Some functions below are not actually "nogil" + CStatus ConvertPySequence(object obj, object mask, const PyConversionOptions& options, shared_ptr[CChunkedArray]* out) @@ -1342,6 +1344,11 @@ cdef extern from 'arrow/python/init.h': int arrow_init_numpy() except -1 +cdef extern from 'arrow/python/common.h' namespace "arrow::py": + c_bool IsPyError(const CStatus& status) + void RestorePyError(const CStatus& status) + + cdef extern from 'arrow/python/pyarrow.h' namespace 'arrow::py': int import_pyarrow() except -1 diff --git a/python/pyarrow/includes/libplasma.pxd b/python/pyarrow/includes/libplasma.pxd new file mode 100644 index 00000000000..1b84ab4e0a6 --- /dev/null +++ b/python/pyarrow/includes/libplasma.pxd @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from pyarrow.includes.common cimport * + +cdef extern from "plasma/common.h" namespace "plasma" nogil: + cdef c_bool IsPlasmaObjectExists(const CStatus& status) + cdef c_bool IsPlasmaObjectNonexistent(const CStatus& status) + cdef c_bool IsPlasmaStoreFull(const CStatus& status) diff --git a/python/pyarrow/plasma.py b/python/pyarrow/plasma.py index 748de97c363..43ca471e0b2 100644 --- a/python/pyarrow/plasma.py +++ b/python/pyarrow/plasma.py @@ -27,7 +27,9 @@ import time from pyarrow._plasma import (ObjectID, ObjectNotAvailable, # noqa - PlasmaBuffer, PlasmaClient, connect) + PlasmaBuffer, PlasmaClient, connect, + PlasmaObjectExists, PlasmaObjectNonexistent, + PlasmaStoreFull) # The Plasma TensorFlow Operator needs to be compiled on the end user's diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 9d66d96e2c2..f961c00b7ac 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -1479,7 +1479,7 @@ def test_array_masked(): def test_array_from_large_pyints(): # ARROW-5430 - with pytest.raises(pa.ArrowInvalid): + with pytest.raises(OverflowError): # too large for int64 so dtype must be explicitly provided pa.array([int(2 ** 63)]) diff --git a/python/pyarrow/tests/test_convert_builtin.py b/python/pyarrow/tests/test_convert_builtin.py index 4e040836979..81d5952b4b1 100644 --- a/python/pyarrow/tests/test_convert_builtin.py +++ b/python/pyarrow/tests/test_convert_builtin.py @@ -26,9 +26,12 @@ import datetime import decimal import itertools +import traceback 
+import sys + import numpy as np -import six import pytz +import six int_type_pairs = [ @@ -53,6 +56,19 @@ def __iter__(self): return self.lst.__iter__() +class MyInt: + def __init__(self, value): + self.value = value + + def __int__(self): + return self.value + + +class MyBrokenInt: + def __int__(self): + 1/0 # MARKER + + def check_struct_type(ty, expected): """ Check a struct type is as expected, but not taking order into account. @@ -191,7 +207,7 @@ def test_nested_lists(seq): @parametrize_with_iterable_types def test_list_with_non_list(seq): # List types don't accept non-sequences - with pytest.raises(pa.ArrowTypeError): + with pytest.raises(TypeError): pa.array(seq([[], [1, 2], 3]), type=pa.list_(pa.int64())) @@ -299,6 +315,24 @@ def test_sequence_numpy_integer_inferred(seq, np_scalar_pa_type): assert arr.to_pylist() == expected +@parametrize_with_iterable_types +def test_sequence_custom_integers(seq): + expected = [0, 42, 2**33 + 1, -2**63] + data = list(map(MyInt, expected)) + arr = pa.array(seq(data), type=pa.int64()) + assert arr.to_pylist() == expected + + +@parametrize_with_iterable_types +def test_broken_integers(seq): + data = [MyBrokenInt()] + with pytest.raises(ZeroDivisionError) as exc_info: + pa.array(seq(data), type=pa.int64()) + # Original traceback is kept + tb_lines = traceback.format_tb(exc_info.tb) + assert "# MARKER" in tb_lines[-1] + + def test_numpy_scalars_mixed_type(): # ARROW-4324 data = [np.int32(10), np.float32(0.5)] @@ -308,7 +342,7 @@ def test_numpy_scalars_mixed_type(): @pytest.mark.xfail(reason="Type inference for uint64 not implemented", - raises=pa.ArrowException) + raises=OverflowError) def test_uint64_max_convert(): data = [0, np.iinfo(np.uint64).max] @@ -323,20 +357,20 @@ def test_uint64_max_convert(): @pytest.mark.parametrize("bits", [8, 16, 32, 64]) def test_signed_integer_overflow(bits): ty = getattr(pa, "int%d" % bits)() - # XXX ideally would raise OverflowError - with pytest.raises((ValueError, pa.ArrowException)): + # 
XXX ideally would always raise OverflowError + with pytest.raises((OverflowError, pa.ArrowInvalid)): pa.array([2 ** (bits - 1)], ty) - with pytest.raises((ValueError, pa.ArrowException)): + with pytest.raises((OverflowError, pa.ArrowInvalid)): pa.array([-2 ** (bits - 1) - 1], ty) @pytest.mark.parametrize("bits", [8, 16, 32, 64]) def test_unsigned_integer_overflow(bits): ty = getattr(pa, "uint%d" % bits)() - # XXX ideally would raise OverflowError - with pytest.raises((ValueError, pa.ArrowException)): + # XXX ideally would always raise OverflowError + with pytest.raises((OverflowError, pa.ArrowInvalid)): pa.array([2 ** bits], ty) - with pytest.raises((ValueError, pa.ArrowException)): + with pytest.raises((OverflowError, pa.ArrowInvalid)): pa.array([-1], ty) @@ -661,7 +695,7 @@ def test_sequence_explicit_types(input): def test_date32_overflow(): # Overflow data3 = [2**32, None] - with pytest.raises(pa.ArrowException): + with pytest.raises((OverflowError, pa.ArrowException)): pa.array(data3, type=pa.date32()) @@ -831,12 +865,19 @@ def test_sequence_timestamp_from_int_with_unit(): assert repr(arr_ns[0]) == "Timestamp('1970-01-01 00:00:00.000000001')" assert str(arr_ns[0]) == "1970-01-01 00:00:00.000000001" - with pytest.raises(pa.ArrowException): - class CustomClass(): - pass - pa.array([1, CustomClass()], type=ns) - pa.array([1, CustomClass()], type=pa.date32()) - pa.array([1, CustomClass()], type=pa.date64()) + if sys.version_info >= (3,): + expected_exc = TypeError + else: + # Can have "AttributeError: CustomClass instance + # has no attribute '__trunc__'" + expected_exc = (TypeError, AttributeError) + + class CustomClass(): + pass + + for ty in [ns, pa.date32(), pa.date64()]: + with pytest.raises(expected_exc): + pa.array([1, CustomClass()], type=ty) def test_sequence_nesting_levels(): diff --git a/python/pyarrow/tests/test_plasma.py b/python/pyarrow/tests/test_plasma.py index 149bdd54f6c..49808a19ef4 100644 --- a/python/pyarrow/tests/test_plasma.py +++ 
b/python/pyarrow/tests/test_plasma.py @@ -227,7 +227,7 @@ def test_create_and_seal(self): # Make sure that creating the same object twice raises an exception. object_id = random_object_id() self.plasma_client.create_and_seal(object_id, b'a', b'b') - with pytest.raises(pa.PlasmaObjectExists): + with pytest.raises(pa.plasma.PlasmaObjectExists): self.plasma_client.create_and_seal(object_id, b'a', b'b') # Make sure that these objects can be evicted. @@ -852,7 +852,7 @@ def test_use_full_memory(self): for _ in range(2): create_object(self.plasma_client2, DEFAULT_PLASMA_STORE_MEMORY, 0) # Verify that an object that is too large does not fit. - with pytest.raises(pa.lib.PlasmaStoreFull): + with pytest.raises(pa.plasma.PlasmaStoreFull): create_object(self.plasma_client2, DEFAULT_PLASMA_STORE_MEMORY + SMALL_OBJECT_SIZE, 0) From d03074986c4333ef8b5b27e1b72d79a75c42fe26 Mon Sep 17 00:00:00 2001 From: Johan Peltenburg Date: Wed, 3 Jul 2019 22:37:48 +0200 Subject: [PATCH 41/52] [Website] Update Fletcher link and description (#4794) --- site/powered_by.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/site/powered_by.md b/site/powered_by.md index cdbede25ac3..e1a474c68eb 100644 --- a/site/powered_by.md +++ b/site/powered_by.md @@ -70,11 +70,11 @@ short description of your use case. Dremio reads data from any source (RDBMS, HDFS, S3, NoSQL) into Arrow buffers, and provides fast SQL access via ODBC, JDBC, and REST for BI, Python, R, and more (all backed by Apache Arrow). -* **[Fletcher][20]:** Fletcher is an FPGA acceleration framework that can - convert an Arrow schema into an easy-to-use hardware interface. The - accelerator can request data from Arrow tables by supplying row indices. - In turn, the interface provides streams of data of the types defined - through the schema. Furthermore, Arrow alleviates serialization bottlenecks. 
+* **[Fletcher][20]:** Fletcher is a framework that can integrate FPGA + accelerators with tools and frameworks that use the Apache Arrow in-memory + format. From a set of Arrow Schemas, Fletcher generates highly optimized + hardware structures that allow accelerator kernels to read and write + RecordBatches at system bandwidth through easy-to-use interfaces. * **[GeoMesa][8]:** A suite of tools that enables large-scale geospatial query and analytics on distributed computing systems. GeoMesa supports query results in the Arrow IPC format, which can then be used for in-browser @@ -163,7 +163,7 @@ short description of your use case. [17]: https://github.com/red-data-tools/red-arrow/ [18]: https://www.graphistry.com [19]: http://gpuopenanalytics.com -[20]: https://github.com/johanpel/fletcher +[20]: https://github.com/abs-tudelft/fletcher [21]: https://www.paradigm4.com [22]: https://github.com/Paradigm4/stream [23]: https://github.com/jpmorganchase/perspective From 4e099a8948d49c46cbc91d10e98d311dd4ca8098 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 3 Jul 2019 15:43:11 -0700 Subject: [PATCH 42/52] fix merge --- cpp/src/plasma/client.cc | 2 +- cpp/src/plasma/io/connection.cc | 4 ++-- cpp/src/plasma/store.cc | 9 ++++++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index a07420bf639..2a9c4e6c344 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -931,7 +931,7 @@ Status PlasmaClient::Impl::DecodeNotification(const uint8_t* buffer, ObjectID* o int64_t* metadata_size) { std::lock_guard guard(client_mutex_); - auto object_info = flatbuffers::GetRoot(buffer); + auto object_info = flatbuffers::GetRoot(buffer); if (object_info->object_id()->size() != sizeof(ObjectID)) { return Status::Invalid( "The size of ObjectID in the message is different from the size " diff --git a/cpp/src/plasma/io/connection.cc b/cpp/src/plasma/io/connection.cc index 4823385db78..348e382877c 100644 
--- a/cpp/src/plasma/io/connection.cc +++ b/cpp/src/plasma/io/connection.cc @@ -94,7 +94,7 @@ Status ServerConnection::ReadMessage(int64_t type, std::vector* message } // If there was no error, make sure the protocol version matches. if (read_version != kPlasmaProtocolVersion) { - return Status::ProtocolError( + return Status::IOError( "Expected Plasma message protocol version: ", kPlasmaProtocolVersion, ", got protocol version: ", read_version); } @@ -213,7 +213,7 @@ void ClientConnection::ProcessMessageHeader(const std::error_code& ec) { // If there was no error, make sure the protocol version matches. if (read_version_ != kPlasmaProtocolVersion) { - status = Status::ProtocolError( + status = Status::IOError( "Expected Plasma message protocol version: ", kPlasmaProtocolVersion, ", got protocol version: ", read_version_); ProcessError(status); diff --git a/cpp/src/plasma/store.cc b/cpp/src/plasma/store.cc index e59db24eddd..71159c02d46 100644 --- a/cpp/src/plasma/store.cc +++ b/cpp/src/plasma/store.cc @@ -564,13 +564,16 @@ int PlasmaStore::AbortObject(const ObjectID& object_id, ARROW_CHECK(entry != nullptr) << "To abort an object it must be in the object table."; ARROW_CHECK(entry->state != ObjectState::PLASMA_SEALED) << "To abort an object it must not have been sealed."; - if (client->ObjectIDExists(object_id)) { + auto it = client->object_ids.find(object_id); + if (it == client->object_ids.end()) { + // If the client requesting the abort is not the creator, do not + // perform the abort. + return 0; + } else { // The client requesting the abort is the creator. Free the object. 
EraseFromObjectTable(object_id); client->object_ids.erase(it); return 1; - } else { - return 0; } } From c8bbf05d97d1f4d52f863b8d848f27ede3884ce1 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 5 Jul 2019 18:13:39 -0700 Subject: [PATCH 43/52] fix --- cpp/src/plasma/test/serialization_tests.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/plasma/test/serialization_tests.cc b/cpp/src/plasma/test/serialization_tests.cc index 8270f09775e..a8463dcb271 100644 --- a/cpp/src/plasma/test/serialization_tests.cc +++ b/cpp/src/plasma/test/serialization_tests.cc @@ -225,7 +225,6 @@ TEST_F(TestPlasmaSerialization, ReleaseReply) { } TEST_F(TestPlasmaSerialization, DeleteRequest) { - int fd = CreateTemporaryFile(); ObjectID object_id1 = random_object_id(); ASSERT_OK(SendDeleteRequest(client_, std::vector{object_id1})); std::vector data; From d60f69fae3d2a0c9f9d71f24c01eb71547564fbf Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 5 Jul 2019 18:16:37 -0700 Subject: [PATCH 44/52] update --- cpp/src/plasma/protocol.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/plasma/protocol.cc b/cpp/src/plasma/protocol.cc index 520e11b54ca..e768174a3f8 100644 --- a/cpp/src/plasma/protocol.cc +++ b/cpp/src/plasma/protocol.cc @@ -602,6 +602,7 @@ Status SendGetReply(const std::shared_ptr& client, fbb.CreateVector(arrow::util::MakeNonNull(store_fds.data()), store_fds.size()), fbb.CreateVector(arrow::util::MakeNonNull(mmap_sizes.data()), mmap_sizes.size()), fbb.CreateVector(arrow::util::MakeNonNull(handles.data()), handles.size())); + fbb.Finish(message); return PlasmaSend(client, MessageType::PlasmaGetReply, &fbb); } From 19f0a351317bf391404916945dda46fe6a5a81d8 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 5 Jul 2019 18:19:40 -0700 Subject: [PATCH 45/52] update --- cpp/src/plasma/protocol.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/plasma/protocol.cc b/cpp/src/plasma/protocol.cc index e768174a3f8..9fe340b9e41 100644 --- 
a/cpp/src/plasma/protocol.cc +++ b/cpp/src/plasma/protocol.cc @@ -376,6 +376,7 @@ Status SendDeleteReply(const std::shared_ptr& client, fbb.CreateVector( arrow::util::MakeNonNull(reinterpret_cast(errors.data())), object_ids.size())); + fbb.Finish(message); return PlasmaSend(client, MessageType::PlasmaDeleteReply, &fbb); } @@ -465,6 +466,7 @@ Status SendListReply(const std::shared_ptr& client, auto message = fb::CreatePlasmaListReply( fbb, fbb.CreateVector(arrow::util::MakeNonNull(object_infos.data()), object_infos.size())); + fbb.Finish(message); return PlasmaSend(client, MessageType::PlasmaListReply, &fbb); } From fd3a73d9b4fb7a74962cb5a9128e13605f930a70 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 5 Jul 2019 19:22:17 -0700 Subject: [PATCH 46/52] update --- cpp/src/plasma/test/client_tests.cc | 250 ++++++++++---------- cpp/src/plasma/test/external_store_tests.cc | 16 +- 2 files changed, 133 insertions(+), 133 deletions(-) diff --git a/cpp/src/plasma/test/client_tests.cc b/cpp/src/plasma/test/client_tests.cc index 2a8d27648ce..01467f29487 100644 --- a/cpp/src/plasma/test/client_tests.cc +++ b/cpp/src/plasma/test/client_tests.cc @@ -50,7 +50,7 @@ class TestPlasmaStore : public ::testing::Test { // stdout of the object store. Consider changing that. void SetUp() { - ARROW_CHECK_OK(TemporaryDir::Make("cli-test-", &temp_dir_)); + ASSERT_OK(TemporaryDir::Make("cli-test-", &temp_dir_)); store_socket_name_ = temp_dir_->path().ToString() + "store"; std::string plasma_directory = @@ -59,13 +59,13 @@ class TestPlasmaStore : public ::testing::Test { plasma_directory + "/plasma_store_server -m 10000000 -s " + store_socket_name_ + " 1> /dev/null 2> /dev/null & " + "echo $! 
> " + store_socket_name_ + ".pid"; PLASMA_CHECK_SYSTEM(system(plasma_command.c_str())); - ARROW_CHECK_OK(client_.Connect(store_socket_name_, "")); - ARROW_CHECK_OK(client2_.Connect(store_socket_name_, "")); + ASSERT_OK(client_.Connect(store_socket_name_, "")); + ASSERT_OK(client2_.Connect(store_socket_name_, "")); } virtual void TearDown() { - ARROW_CHECK_OK(client_.Disconnect()); - ARROW_CHECK_OK(client2_.Disconnect()); + ASSERT_OK(client_.Disconnect()); + ASSERT_OK(client2_.Disconnect()); // Kill plasma_store process that we started #ifdef COVERAGE_BUILD // Ask plasma_store to exit gracefully and give it time to write out @@ -84,14 +84,14 @@ class TestPlasmaStore : public ::testing::Test { const std::vector& metadata, const std::vector& data, bool release = true) { std::shared_ptr data_buffer; - ARROW_CHECK_OK(client.Create(object_id, data.size(), &metadata[0], metadata.size(), + ASSERT_OK(client.Create(object_id, data.size(), &metadata[0], metadata.size(), &data_buffer)); for (size_t i = 0; i < data.size(); i++) { data_buffer->mutable_data()[i] = data[i]; } - ARROW_CHECK_OK(client.Seal(object_id)); + ASSERT_OK(client.Seal(object_id)); if (release) { - ARROW_CHECK_OK(client.Release(object_id)); + ASSERT_OK(client.Release(object_id)); } } @@ -105,8 +105,8 @@ class TestPlasmaStore : public ::testing::Test { TEST_F(TestPlasmaStore, NewSubscriberTest) { PlasmaClient local_client, local_client2; - ARROW_CHECK_OK(local_client.Connect(store_socket_name_, "")); - ARROW_CHECK_OK(local_client2.Connect(store_socket_name_, "")); + ASSERT_OK(local_client.Connect(store_socket_name_, "")); + ASSERT_OK(local_client2.Connect(store_socket_name_, "")); ObjectID object_id = random_object_id(); @@ -116,34 +116,34 @@ TEST_F(TestPlasmaStore, NewSubscriberTest) { uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - ARROW_CHECK_OK( + ASSERT_OK( local_client.Create(object_id, data_size, metadata, metadata_size, &data)); - 
ARROW_CHECK_OK(local_client.Seal(object_id)); + ASSERT_OK(local_client.Seal(object_id)); // Test that new subscriber client2 can receive notifications about existing objects. - ARROW_CHECK_OK(local_client2.Subscribe()); + ASSERT_OK(local_client2.Subscribe()); ObjectID object_id2 = random_object_id(); int64_t data_size2 = 0; int64_t metadata_size2 = 0; - ARROW_CHECK_OK( + ASSERT_OK( local_client2.GetNotification(&object_id2, &data_size2, &metadata_size2)); ASSERT_EQ(object_id, object_id2); ASSERT_EQ(data_size, data_size2); ASSERT_EQ(metadata_size, metadata_size2); // Delete the object. - ARROW_CHECK_OK(local_client.Release(object_id)); - ARROW_CHECK_OK(local_client.Delete(object_id)); + ASSERT_OK(local_client.Release(object_id)); + ASSERT_OK(local_client.Delete(object_id)); - ARROW_CHECK_OK( + ASSERT_OK( local_client2.GetNotification(&object_id2, &data_size2, &metadata_size2)); ASSERT_EQ(object_id, object_id2); ASSERT_EQ(-1, data_size2); ASSERT_EQ(-1, metadata_size2); - ARROW_CHECK_OK(local_client2.Disconnect()); - ARROW_CHECK_OK(local_client.Disconnect()); + ASSERT_OK(local_client2.Disconnect()); + ASSERT_OK(local_client.Disconnect()); } TEST_F(TestPlasmaStore, SealErrorsTest) { @@ -159,7 +159,7 @@ TEST_F(TestPlasmaStore, SealErrorsTest) { // Trying to seal it again. result = client_.Seal(object_id); ASSERT_TRUE(IsPlasmaObjectAlreadySealed(result)); - ARROW_CHECK_OK(client_.Release(object_id)); + ASSERT_OK(client_.Release(object_id)); } TEST_F(TestPlasmaStore, DeleteTest) { @@ -167,7 +167,7 @@ TEST_F(TestPlasmaStore, DeleteTest) { // Test for deleting non-existance object. Status result = client_.Delete(object_id); - ARROW_CHECK_OK(result); + ASSERT_OK(result); // Test for the object being in local Plasma store. // First create object. 
@@ -175,20 +175,20 @@ TEST_F(TestPlasmaStore, DeleteTest) { uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); - ARROW_CHECK_OK(client_.Seal(object_id)); + ASSERT_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Seal(object_id)); result = client_.Delete(object_id); - ARROW_CHECK_OK(result); + ASSERT_OK(result); bool has_object = false; - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_TRUE(has_object); - ARROW_CHECK_OK(client_.Release(object_id)); + ASSERT_OK(client_.Release(object_id)); // object_id is marked as to-be-deleted, when it is not in use, it will be deleted. - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_FALSE(has_object); - ARROW_CHECK_OK(client_.Delete(object_id)); + ASSERT_OK(client_.Delete(object_id)); } TEST_F(TestPlasmaStore, DeleteObjectsTest) { @@ -197,31 +197,31 @@ TEST_F(TestPlasmaStore, DeleteObjectsTest) { // Test for deleting non-existance object. Status result = client_.Delete(std::vector{object_id1, object_id2}); - ARROW_CHECK_OK(result); + ASSERT_OK(result); // Test for the object being in local Plasma store. // First create object. 
int64_t data_size = 100; uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id1, data_size, metadata, metadata_size, &data)); - ARROW_CHECK_OK(client_.Seal(object_id1)); - ARROW_CHECK_OK(client_.Create(object_id2, data_size, metadata, metadata_size, &data)); - ARROW_CHECK_OK(client_.Seal(object_id2)); + ASSERT_OK(client_.Create(object_id1, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Seal(object_id1)); + ASSERT_OK(client_.Create(object_id2, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Seal(object_id2)); // Release the ref count of Create function. - ARROW_CHECK_OK(client_.Release(object_id1)); - ARROW_CHECK_OK(client_.Release(object_id2)); + ASSERT_OK(client_.Release(object_id1)); + ASSERT_OK(client_.Release(object_id2)); // Increase the ref count by calling Get using client2_. std::vector object_buffers; - ARROW_CHECK_OK(client2_.Get({object_id1, object_id2}, 0, &object_buffers)); + ASSERT_OK(client2_.Get({object_id1, object_id2}, 0, &object_buffers)); // Objects are still used by client2_. result = client_.Delete(std::vector{object_id1, object_id2}); - ARROW_CHECK_OK(result); + ASSERT_OK(result); // The object is used and it should not be deleted right now. bool has_object = false; - ARROW_CHECK_OK(client_.Contains(object_id1, &has_object)); + ASSERT_OK(client_.Contains(object_id1, &has_object)); ASSERT_TRUE(has_object); - ARROW_CHECK_OK(client_.Contains(object_id2, &has_object)); + ASSERT_OK(client_.Contains(object_id2, &has_object)); ASSERT_TRUE(has_object); // Decrease the ref count by deleting the PlasmaBuffer (in ObjectBuffer). // client2_ won't send the release request immediately because the trigger @@ -229,9 +229,9 @@ TEST_F(TestPlasmaStore, DeleteObjectsTest) { object_buffers.clear(); // Delete the objects. 
result = client2_.Delete(std::vector{object_id1, object_id2}); - ARROW_CHECK_OK(client_.Contains(object_id1, &has_object)); + ASSERT_OK(client_.Contains(object_id1, &has_object)); ASSERT_FALSE(has_object); - ARROW_CHECK_OK(client_.Contains(object_id2, &has_object)); + ASSERT_OK(client_.Contains(object_id2, &has_object)); ASSERT_FALSE(has_object); } @@ -240,7 +240,7 @@ TEST_F(TestPlasmaStore, ContainsTest) { // Test for object non-existence. bool has_object; - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_FALSE(has_object); // Test for the object being in local Plasma store. @@ -248,8 +248,8 @@ TEST_F(TestPlasmaStore, ContainsTest) { std::vector data(100, 0); CreateObject(client_, object_id, {42}, data); std::vector object_buffers; - ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers)); - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Get({object_id}, -1, &object_buffers)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_TRUE(has_object); } @@ -259,7 +259,7 @@ TEST_F(TestPlasmaStore, GetTest) { ObjectID object_id = random_object_id(); // Test for object non-existence. - ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, 0, &object_buffers)); ASSERT_EQ(object_buffers.size(), 1); ASSERT_FALSE(object_buffers[0].metadata); ASSERT_FALSE(object_buffers[0].data); @@ -272,7 +272,7 @@ TEST_F(TestPlasmaStore, GetTest) { EXPECT_FALSE(client_.IsInUse(object_id)); object_buffers.clear(); - ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, -1, &object_buffers)); ASSERT_EQ(object_buffers.size(), 1); ASSERT_EQ(object_buffers[0].device_num, 0); AssertObjectBufferEqual(object_buffers[0], {42}, {3, 5, 6, 7, 9}); @@ -295,7 +295,7 @@ TEST_F(TestPlasmaStore, LegacyGetTest) { ObjectBuffer object_buffer; // Test for object non-existence. 
- ARROW_CHECK_OK(client_.Get(&object_id, 1, 0, &object_buffer)); + ASSERT_OK(client_.Get(&object_id, 1, 0, &object_buffer)); ASSERT_FALSE(object_buffer.metadata); ASSERT_FALSE(object_buffer.data); EXPECT_FALSE(client_.IsInUse(object_id)); @@ -305,12 +305,12 @@ TEST_F(TestPlasmaStore, LegacyGetTest) { CreateObject(client_, object_id, {42}, data); EXPECT_FALSE(client_.IsInUse(object_id)); - ARROW_CHECK_OK(client_.Get(&object_id, 1, -1, &object_buffer)); + ASSERT_OK(client_.Get(&object_id, 1, -1, &object_buffer)); AssertObjectBufferEqual(object_buffer, {42}, {3, 5, 6, 7, 9}); } // Object needs releasing manually EXPECT_TRUE(client_.IsInUse(object_id)); - ARROW_CHECK_OK(client_.Release(object_id)); + ASSERT_OK(client_.Release(object_id)); EXPECT_FALSE(client_.IsInUse(object_id)); } @@ -324,15 +324,15 @@ TEST_F(TestPlasmaStore, MultipleGetTest) { uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id1, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Create(object_id1, data_size, metadata, metadata_size, &data)); data->mutable_data()[0] = 1; - ARROW_CHECK_OK(client_.Seal(object_id1)); + ASSERT_OK(client_.Seal(object_id1)); - ARROW_CHECK_OK(client_.Create(object_id2, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Create(object_id2, data_size, metadata, metadata_size, &data)); data->mutable_data()[0] = 2; - ARROW_CHECK_OK(client_.Seal(object_id2)); + ASSERT_OK(client_.Seal(object_id2)); - ARROW_CHECK_OK(client_.Get(object_ids, -1, &object_buffers)); + ASSERT_OK(client_.Get(object_ids, -1, &object_buffers)); ASSERT_EQ(object_buffers[0].data->data()[0], 1); ASSERT_EQ(object_buffers[1].data->data()[0], 2); } @@ -342,7 +342,7 @@ TEST_F(TestPlasmaStore, AbortTest) { std::vector object_buffers; // Test for object non-existence. 
- ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, 0, &object_buffers)); ASSERT_FALSE(object_buffers[0].data); // Test object abort. @@ -352,7 +352,7 @@ TEST_F(TestPlasmaStore, AbortTest) { int64_t metadata_size = sizeof(metadata); std::shared_ptr data; uint8_t* data_ptr; - ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); data_ptr = data->mutable_data(); // Write some data. for (int64_t i = 0; i < data_size / 2; i++) { @@ -362,21 +362,21 @@ TEST_F(TestPlasmaStore, AbortTest) { Status status = client_.Abort(object_id); ASSERT_TRUE(status.IsInvalid()); // Release, then abort. - ARROW_CHECK_OK(client_.Release(object_id)); + ASSERT_OK(client_.Release(object_id)); EXPECT_TRUE(client_.IsInUse(object_id)); - ARROW_CHECK_OK(client_.Abort(object_id)); + ASSERT_OK(client_.Abort(object_id)); EXPECT_FALSE(client_.IsInUse(object_id)); // Test for object non-existence after the abort. - ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, 0, &object_buffers)); ASSERT_FALSE(object_buffers[0].data); // Create the object successfully this time. CreateObject(client_, object_id, {42, 43}, {1, 2, 3, 4, 5}); // Test that we can get the object. - ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, -1, &object_buffers)); AssertObjectBufferEqual(object_buffers[0], {42, 43}, {1, 2, 3, 4, 5}); } @@ -387,7 +387,7 @@ TEST_F(TestPlasmaStore, OneIdCreateRepeatedlyTest) { std::vector object_buffers; // Test for object non-existence. 
- ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, 0, &object_buffers)); ASSERT_FALSE(object_buffers[0].data); int64_t data_size = 20; @@ -397,18 +397,18 @@ TEST_F(TestPlasmaStore, OneIdCreateRepeatedlyTest) { // Test the sequence: create -> release -> abort -> ... for (int64_t i = 0; i < loop_times; i++) { std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); - ARROW_CHECK_OK(client_.Release(object_id)); - ARROW_CHECK_OK(client_.Abort(object_id)); + ASSERT_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Release(object_id)); + ASSERT_OK(client_.Abort(object_id)); } // Test the sequence: create -> seal -> release -> delete -> ... for (int64_t i = 0; i < loop_times; i++) { std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); - ARROW_CHECK_OK(client_.Seal(object_id)); - ARROW_CHECK_OK(client_.Release(object_id)); - ARROW_CHECK_OK(client_.Delete(object_id)); + ASSERT_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Seal(object_id)); + ASSERT_OK(client_.Release(object_id)); + ASSERT_OK(client_.Delete(object_id)); } } @@ -418,7 +418,7 @@ TEST_F(TestPlasmaStore, MultipleClientTest) { // Test for object non-existence on the first client. bool has_object; - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_FALSE(has_object); // Test for the object being in local Plasma store. 
@@ -427,25 +427,25 @@ TEST_F(TestPlasmaStore, MultipleClientTest) { uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - ARROW_CHECK_OK(client2_.Create(object_id, data_size, metadata, metadata_size, &data)); - ARROW_CHECK_OK(client2_.Seal(object_id)); + ASSERT_OK(client2_.Create(object_id, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client2_.Seal(object_id)); // Test that the first client can get the object. - ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, -1, &object_buffers)); ASSERT_TRUE(object_buffers[0].data); - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_TRUE(has_object); // Test that one client disconnecting does not interfere with the other. // First create object on the second client. object_id = random_object_id(); - ARROW_CHECK_OK(client2_.Create(object_id, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client2_.Create(object_id, data_size, metadata, metadata_size, &data)); // Disconnect the first client. - ARROW_CHECK_OK(client_.Disconnect()); + ASSERT_OK(client_.Disconnect()); // Test that the second client can seal and get the created object. - ARROW_CHECK_OK(client2_.Seal(object_id)); - ARROW_CHECK_OK(client2_.Get({object_id}, -1, &object_buffers)); + ASSERT_OK(client2_.Seal(object_id)); + ASSERT_OK(client2_.Get({object_id}, -1, &object_buffers)); ASSERT_TRUE(object_buffers[0].data); - ARROW_CHECK_OK(client2_.Contains(object_id, &has_object)); + ASSERT_OK(client2_.Contains(object_id, &has_object)); ASSERT_TRUE(has_object); } @@ -459,7 +459,7 @@ TEST_F(TestPlasmaStore, ManyObjectTest) { // Test for object non-existence on the first client. bool has_object; - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_FALSE(has_object); // Test for the object being in local Plasma store. 
@@ -468,29 +468,29 @@ TEST_F(TestPlasmaStore, ManyObjectTest) { uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); + ASSERT_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data)); if (i % 3 == 0) { // Seal one third of the objects. - ARROW_CHECK_OK(client_.Seal(object_id)); + ASSERT_OK(client_.Seal(object_id)); // Test that the first client can get the object. - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_TRUE(has_object); } else if (i % 3 == 1) { // Abort one third of the objects. - ARROW_CHECK_OK(client_.Release(object_id)); - ARROW_CHECK_OK(client_.Abort(object_id)); + ASSERT_OK(client_.Release(object_id)); + ASSERT_OK(client_.Abort(object_id)); } } // Disconnect the first client. All unsealed objects should be aborted. - ARROW_CHECK_OK(client_.Disconnect()); + ASSERT_OK(client_.Disconnect()); // Check that the second client can query the object store for the first // client's objects. int i = 0; for (auto const& object_id : object_ids) { bool has_object; - ARROW_CHECK_OK(client2_.Contains(object_id, &has_object)); + ASSERT_OK(client2_.Contains(object_id, &has_object)); if (i % 3 == 0) { // The first third should be sealed. 
ASSERT_TRUE(has_object); @@ -514,13 +514,13 @@ void AssertCudaRead(const std::shared_ptr& buffer, std::shared_ptr gpu_buffer; const size_t data_size = expected_data.size(); - ARROW_CHECK_OK(CudaBuffer::FromBuffer(buffer, &gpu_buffer)); + ASSERT_OK(CudaBuffer::FromBuffer(buffer, &gpu_buffer)); ASSERT_EQ(gpu_buffer->size(), data_size); CudaBufferReader reader(gpu_buffer); std::vector read_data(data_size); int64_t read_data_size; - ARROW_CHECK_OK(reader.Read(data_size, &read_data_size, read_data.data())); + ASSERT_OK(reader.Read(data_size, &read_data_size, read_data.data())); ASSERT_EQ(read_data_size, data_size); for (size_t i = 0; i < data_size; i++) { @@ -535,7 +535,7 @@ TEST_F(TestPlasmaStore, GetGPUTest) { std::vector object_buffers; // Test for object non-existence. - ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, 0, &object_buffers)); ASSERT_EQ(object_buffers.size(), 1); ASSERT_FALSE(object_buffers[0].data); @@ -547,15 +547,15 @@ TEST_F(TestPlasmaStore, GetGPUTest) { int64_t metadata_size = sizeof(metadata); std::shared_ptr data_buffer; std::shared_ptr gpu_buffer; - ARROW_CHECK_OK( + ASSERT_OK( client_.Create(object_id, data_size, metadata, metadata_size, &data_buffer, 1)); - ARROW_CHECK_OK(CudaBuffer::FromBuffer(data_buffer, &gpu_buffer)); + ASSERT_OK(CudaBuffer::FromBuffer(data_buffer, &gpu_buffer)); CudaBufferWriter writer(gpu_buffer); - ARROW_CHECK_OK(writer.Write(data, data_size)); - ARROW_CHECK_OK(client_.Seal(object_id)); + ASSERT_OK(writer.Write(data, data_size)); + ASSERT_OK(client_.Seal(object_id)); object_buffers.clear(); - ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers)); + ASSERT_OK(client_.Get({object_id}, -1, &object_buffers)); ASSERT_EQ(object_buffers.size(), 1); ASSERT_EQ(object_buffers[0].device_num, 1); // Check data @@ -570,34 +570,34 @@ TEST_F(TestPlasmaStore, DeleteObjectsGPUTest) { // Test for deleting non-existance object. 
Status result = client_.Delete(std::vector{object_id1, object_id2}); - ARROW_CHECK_OK(result); + ASSERT_OK(result); // Test for the object being in local Plasma store. // First create object. int64_t data_size = 100; uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - ARROW_CHECK_OK( + ASSERT_OK( client_.Create(object_id1, data_size, metadata, metadata_size, &data, 1)); - ARROW_CHECK_OK(client_.Seal(object_id1)); - ARROW_CHECK_OK( + ASSERT_OK(client_.Seal(object_id1)); + ASSERT_OK( client_.Create(object_id2, data_size, metadata, metadata_size, &data, 1)); - ARROW_CHECK_OK(client_.Seal(object_id2)); + ASSERT_OK(client_.Seal(object_id2)); // Release the ref count of Create function. data = nullptr; - ARROW_CHECK_OK(client_.Release(object_id1)); - ARROW_CHECK_OK(client_.Release(object_id2)); + ASSERT_OK(client_.Release(object_id1)); + ASSERT_OK(client_.Release(object_id2)); // Increase the ref count by calling Get using client2_. std::vector object_buffers; - ARROW_CHECK_OK(client2_.Get({object_id1, object_id2}, 0, &object_buffers)); + ASSERT_OK(client2_.Get({object_id1, object_id2}, 0, &object_buffers)); // Objects are still used by client2_. result = client_.Delete(std::vector{object_id1, object_id2}); - ARROW_CHECK_OK(result); + ASSERT_OK(result); // The object is used and it should not be deleted right now. bool has_object = false; - ARROW_CHECK_OK(client_.Contains(object_id1, &has_object)); + ASSERT_OK(client_.Contains(object_id1, &has_object)); ASSERT_TRUE(has_object); - ARROW_CHECK_OK(client_.Contains(object_id2, &has_object)); + ASSERT_OK(client_.Contains(object_id2, &has_object)); ASSERT_TRUE(has_object); // Decrease the ref count by deleting the PlasmaBuffer (in ObjectBuffer). // client2_ won't send the release request immediately because the trigger @@ -605,9 +605,9 @@ TEST_F(TestPlasmaStore, DeleteObjectsGPUTest) { object_buffers.clear(); // Delete the objects. 
result = client2_.Delete(std::vector{object_id1, object_id2}); - ARROW_CHECK_OK(client_.Contains(object_id1, &has_object)); + ASSERT_OK(client_.Contains(object_id1, &has_object)); ASSERT_FALSE(has_object); - ARROW_CHECK_OK(client_.Contains(object_id2, &has_object)); + ASSERT_OK(client_.Contains(object_id2, &has_object)); ASSERT_FALSE(has_object); } @@ -624,27 +624,27 @@ TEST_F(TestPlasmaStore, RepeatlyCreateGPUTest) { ObjectID& object_id = object_ids[i]; std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id, data_size, 0, 0, &data, 1)); - ARROW_CHECK_OK(client_.Seal(object_id)); - ARROW_CHECK_OK(client_.Release(object_id)); + ASSERT_OK(client_.Create(object_id, data_size, 0, 0, &data, 1)); + ASSERT_OK(client_.Seal(object_id)); + ASSERT_OK(client_.Release(object_id)); } // delete and create again for (int64_t i = 0; i < loop_times; i++) { ObjectID& object_id = object_ids[i % object_num]; - ARROW_CHECK_OK(client_.Delete(object_id)); + ASSERT_OK(client_.Delete(object_id)); std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id, data_size, 0, 0, &data, 1)); - ARROW_CHECK_OK(client_.Seal(object_id)); + ASSERT_OK(client_.Create(object_id, data_size, 0, 0, &data, 1)); + ASSERT_OK(client_.Seal(object_id)); data = nullptr; - ARROW_CHECK_OK(client_.Release(object_id)); + ASSERT_OK(client_.Release(object_id)); } // delete all - ARROW_CHECK_OK(client_.Delete(object_ids)); + ASSERT_OK(client_.Delete(object_ids)); } TEST_F(TestPlasmaStore, MultipleClientGPUTest) { @@ -653,7 +653,7 @@ TEST_F(TestPlasmaStore, MultipleClientGPUTest) { // Test for object non-existence on the first client. bool has_object; - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_FALSE(has_object); // Test for the object being in local Plasma store. 
@@ -662,27 +662,27 @@ TEST_F(TestPlasmaStore, MultipleClientGPUTest) { uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - ARROW_CHECK_OK( + ASSERT_OK( client2_.Create(object_id, data_size, metadata, metadata_size, &data, 1)); - ARROW_CHECK_OK(client2_.Seal(object_id)); + ASSERT_OK(client2_.Seal(object_id)); // Test that the first client can get the object. - ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers)); - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Get({object_id}, -1, &object_buffers)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_TRUE(has_object); // Test that one client disconnecting does not interfere with the other. // First create object on the second client. object_id = random_object_id(); - ARROW_CHECK_OK( + ASSERT_OK( client2_.Create(object_id, data_size, metadata, metadata_size, &data, 1)); // Disconnect the first client. - ARROW_CHECK_OK(client_.Disconnect()); + ASSERT_OK(client_.Disconnect()); // Test that the second client can seal and get the created object. 
- ARROW_CHECK_OK(client2_.Seal(object_id)); + ASSERT_OK(client2_.Seal(object_id)); object_buffers.clear(); - ARROW_CHECK_OK(client2_.Contains(object_id, &has_object)); + ASSERT_OK(client2_.Contains(object_id, &has_object)); ASSERT_TRUE(has_object); - ARROW_CHECK_OK(client2_.Get({object_id}, -1, &object_buffers)); + ASSERT_OK(client2_.Get({object_id}, -1, &object_buffers)); ASSERT_EQ(object_buffers.size(), 1); ASSERT_EQ(object_buffers[0].device_num, 1); AssertCudaRead(object_buffers[0].metadata, {5}); diff --git a/cpp/src/plasma/test/external_store_tests.cc b/cpp/src/plasma/test/external_store_tests.cc index 2b7d3217352..467d11cfa1c 100644 --- a/cpp/src/plasma/test/external_store_tests.cc +++ b/cpp/src/plasma/test/external_store_tests.cc @@ -54,7 +54,7 @@ class TestPlasmaStoreWithExternal : public ::testing::Test { // TODO(pcm): At the moment, stdout of the test gets mixed up with // stdout of the object store. Consider changing that. void SetUp() override { - ARROW_CHECK_OK(TemporaryDir::Make("ext-test-", &temp_dir_)); + ASSERT_OK(TemporaryDir::Make("ext-test-", &temp_dir_)); store_socket_name_ = temp_dir_->path().ToString() + "store"; std::string plasma_directory = @@ -65,11 +65,11 @@ class TestPlasmaStoreWithExternal : public ::testing::Test { " 1> /tmp/log.stdout 2> /tmp/log.stderr & " + "echo $! > " + store_socket_name_ + ".pid"; PLASMA_CHECK_SYSTEM(system(plasma_command.c_str())); - ARROW_CHECK_OK(client_.Connect(store_socket_name_, "")); + ASSERT_OK(client_.Connect(store_socket_name_, "")); } void TearDown() override { - ARROW_CHECK_OK(client_.Disconnect()); + ASSERT_OK(client_.Disconnect()); // Kill plasma_store process that we started #ifdef COVERAGE_BUILD // Ask plasma_store to exit gracefully and give it time to write out @@ -100,14 +100,14 @@ TEST_F(TestPlasmaStoreWithExternal, EvictionTest) { // Test for object non-existence. 
bool has_object; - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_FALSE(has_object); // Test for the object being in local Plasma store. // Create and seal the object. - ARROW_CHECK_OK(client_.CreateAndSeal(object_id, data, metadata)); + ASSERT_OK(client_.CreateAndSeal(object_id, data, metadata)); // Test that the client can get the object. - ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_OK(client_.Contains(object_id, &has_object)); ASSERT_TRUE(has_object); } @@ -118,7 +118,7 @@ TEST_F(TestPlasmaStoreWithExternal, EvictionTest) { // external store on failure. This should succeed to fetch the object. // However, it may evict the next few objects. std::vector object_buffers; - ARROW_CHECK_OK(client_.Get({object_ids[i]}, -1, &object_buffers)); + ASSERT_OK(client_.Get({object_ids[i]}, -1, &object_buffers)); ASSERT_EQ(object_buffers.size(), 1); ASSERT_EQ(object_buffers[0].device_num, 0); ASSERT_TRUE(object_buffers[0].data); @@ -127,7 +127,7 @@ TEST_F(TestPlasmaStoreWithExternal, EvictionTest) { // Make sure we still cannot fetch objects that do not exist std::vector object_buffers; - ARROW_CHECK_OK(client_.Get({random_object_id()}, 100, &object_buffers)); + ASSERT_OK(client_.Get({random_object_id()}, 100, &object_buffers)); ASSERT_EQ(object_buffers.size(), 1); ASSERT_EQ(object_buffers[0].device_num, 0); ASSERT_EQ(object_buffers[0].data, nullptr); From b6d7461838f840335538d53af9603f125ddf910d Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 13 Jul 2019 18:25:20 -0700 Subject: [PATCH 47/52] update --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 27 ---------------- cpp/src/plasma/CMakeLists.txt | 2 -- cpp/src/plasma/io/basic_connection.cc | 36 ++++++++++----------- cpp/src/plasma/io/basic_connection.h | 19 +++++++---- cpp/src/plasma/io/connection.cc | 8 ++--- cpp/src/plasma/io/connection.h | 6 ++-- cpp/src/plasma/store.cc | 8 ++--- cpp/src/plasma/store.h | 2 +- 8 files 
changed, 42 insertions(+), 66 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 2f25d866fa4..2c4a4971dea 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -2374,33 +2374,6 @@ if(ARROW_ORC) message(STATUS "Found ORC headers: ${ORC_INCLUDE_DIR}") endif() -# ---------------------------------------------------------------------- -# Plasma - -if(ARROW_PLASMA) - externalproject_add(asio_ep - URL - https://github.com/chriskohlhoff/asio/archive/asio-1-12-2.zip - CONFIGURE_COMMAND - "" # No autogen since we use asio in header-only way. - BUILD_IN_SOURCE - 1 - BUILD_COMMAND - "" - INSTALL_COMMAND - cmake - -E - copy - asio/include/asio.hpp - ${PROJECT_BINARY_DIR}/src - && - cmake - -E - copy_directory - asio/include/asio - ${PROJECT_BINARY_DIR}/src/asio/) -endif() - # Write out the package configurations. configure_file("src/arrow/util/config.h.cmake" "src/arrow/util/config.h") diff --git a/cpp/src/plasma/CMakeLists.txt b/cpp/src/plasma/CMakeLists.txt index ff5532964fc..48356184784 100644 --- a/cpp/src/plasma/CMakeLists.txt +++ b/cpp/src/plasma/CMakeLists.txt @@ -113,7 +113,6 @@ add_arrow_lib(plasma PLASMA_LIBRARIES DEPENDENCIES gen_plasma_fbs - asio_ep SHARED_LINK_FLAGS ${PLASMA_SHARED_LINK_FLAGS} SHARED_LINK_LIBS @@ -158,7 +157,6 @@ else() endif() add_dependencies(plasma plasma_store_server) -add_dependencies(plasma_store_server asio_ep) if(ARROW_RPATH_ORIGIN) if(APPLE) diff --git a/cpp/src/plasma/io/basic_connection.cc b/cpp/src/plasma/io/basic_connection.cc index 1ba9571733b..15ae7a05385 100644 --- a/cpp/src/plasma/io/basic_connection.cc +++ b/cpp/src/plasma/io/basic_connection.cc @@ -31,14 +31,14 @@ namespace io { /// \param socket The socket to connect. /// \param socket_name The name/path of the socket. /// \return Status. 
-std::error_code UnixDomainSocketConnect(asio::local::stream_protocol::socket& socket, +error_code UnixDomainSocketConnect(asio::local::stream_protocol::socket& socket, const std::string& socket_name) { asio::local::stream_protocol::endpoint endpoint(socket_name); - std::error_code ec; + error_code ec; socket.connect(endpoint, ec); if (ec) { // Close the socket if the connect failed. - std::error_code close_error; + error_code close_error; socket.close(close_error); } return ec; @@ -52,7 +52,7 @@ PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& #ifndef _WIN32 asio::basic_stream_socket socket(io_context); for (int i = 0; i < num_retries; i++) { - std::error_code ec = UnixDomainSocketConnect(socket, name); + error_code ec = UnixDomainSocketConnect(socket, name); if (!ec) { break; } @@ -92,13 +92,13 @@ Connection::~Connection() { // If there are any pending messages, invoke their callbacks with an IOError status. for (const auto& write_buffer : async_write_queue_) { write_buffer->Handle( - std::error_code(static_cast(std::errc::io_error), std::system_category())); + error_code(static_cast(boost::system::errc::io_error), boost::system::system_category())); } } template -std::error_code Connection::ReadBuffer(const asio::mutable_buffer& buffer) { - std::error_code ec; +error_code Connection::ReadBuffer(const asio::mutable_buffer& buffer) { + error_code ec; // Loop until all bytes are read while handling interrupts. uint64_t bytes_remaining = asio::buffer_size(buffer); uint64_t position = 0; @@ -113,26 +113,26 @@ std::error_code Connection::ReadBuffer(const asio::mutable_buffer& buffer) { return ec; } } - return std::error_code(); + return error_code(); } template -std::error_code Connection::ReadBuffer( +error_code Connection::ReadBuffer( const std::vector& buffer) { // Loop until all bytes are read while handling interrupts. 
for (const auto& b : buffer) { auto ec = ReadBuffer(b); if (ec) return ec; } - return std::error_code(); + return error_code(); } /// Write a buffer to this connection. /// /// \param buffer The buffer. template -std::error_code Connection::WriteBuffer(const asio::const_buffer& buffer) { - std::error_code error; +error_code Connection::WriteBuffer(const asio::const_buffer& buffer) { + error_code error; // Loop until all bytes are written while handling interrupts. // When profiling with pprof, unhandled interrupts were being sent by the profiler to // the raylet process, which was causing synchronous reads and writes to fail. @@ -149,13 +149,13 @@ std::error_code Connection::WriteBuffer(const asio::const_buffer& buffer) { return error; } } - return std::error_code(); + return error_code(); } template -std::error_code Connection::WriteBuffer( +error_code Connection::WriteBuffer( const std::vector& buffer) { - std::error_code error; + error_code error; // Loop until all bytes are written while handling interrupts. // When profiling with pprof, unhandled interrupts were being sent by the profiler to // the raylet process, which was causing synchronous reads and writes to fail. @@ -165,7 +165,7 @@ std::error_code Connection::WriteBuffer( return error; } } - return std::error_code(); + return error_code(); } template @@ -185,7 +185,7 @@ void Connection::WriteBufferAsync(std::unique_ptr write_buf // Shuts down socket for this connection. template void Connection::Close() { - std::error_code ec; + error_code ec; stream_.close(ec); } @@ -221,7 +221,7 @@ void Connection::DoAsyncWrites() { // Ensure lambda holds a reference to this. 
auto this_ptr = this->shared_from_this(); asio::async_write(stream_, message_buffers, - [this, this_ptr, num_messages](const std::error_code& ec, + [this, this_ptr, num_messages](const error_code& ec, size_t bytes_transferred) { bytes_written_ += bytes_transferred; bool close_connection = false; diff --git a/cpp/src/plasma/io/basic_connection.h b/cpp/src/plasma/io/basic_connection.h index c121dedefcf..b1fc00abf51 100644 --- a/cpp/src/plasma/io/basic_connection.h +++ b/cpp/src/plasma/io/basic_connection.h @@ -29,7 +29,12 @@ #include #include -#include "asio.hpp" // NOLINT +#include +#include + +namespace asio = boost::asio; + +using error_code = boost::system::error_code; namespace plasma { namespace io { @@ -40,7 +45,7 @@ enum class AsyncWriteCallbackCode { UNKNOWN_ERROR, }; -using AsyncWriteCallback = std::function; +using AsyncWriteCallback = std::function; // TODO(suquark): Change it according to the platform. using PlasmaStream = asio::basic_stream_socket; using PlasmaAcceptor = asio::local::stream_protocol::acceptor; @@ -55,7 +60,7 @@ PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& struct AsyncWriteBuffer { virtual void ToBuffers(std::vector& message_buffers) = 0; virtual ~AsyncWriteBuffer() {} - inline AsyncWriteCallbackCode Handle(const std::error_code& ec) { return handler_(ec); } + inline AsyncWriteCallbackCode Handle(const error_code& ec) { return handler_(ec); } protected: AsyncWriteCallback handler_; @@ -71,22 +76,22 @@ class Connection : public std::enable_shared_from_this> { /// Read a buffer from this connection. /// /// \param buffer The output buffer. - std::error_code ReadBuffer(const asio::mutable_buffer& buffer); + error_code ReadBuffer(const asio::mutable_buffer& buffer); /// Read buffers from this connection. /// /// \param buffer The output vector of buffers. - std::error_code ReadBuffer(const std::vector& buffer); + error_code ReadBuffer(const std::vector& buffer); /// Write a buffer to this connection. 
/// /// \param buffer The buffer. - std::error_code WriteBuffer(const asio::const_buffer& buffer); + error_code WriteBuffer(const asio::const_buffer& buffer); /// Write buffers to this connection. /// /// \param buffer The vector of buffers. - std::error_code WriteBuffer(const std::vector& buffer); + error_code WriteBuffer(const std::vector& buffer); /// Write buffers to this connection async. /// diff --git a/cpp/src/plasma/io/connection.cc b/cpp/src/plasma/io/connection.cc index 348e382877c..97300a7a777 100644 --- a/cpp/src/plasma/io/connection.cc +++ b/cpp/src/plasma/io/connection.cc @@ -36,7 +36,7 @@ namespace io { using flatbuf::MessageType; -Status asio_to_arrow_status(const std::error_code& ec) { +Status asio_to_arrow_status(const error_code& ec) { if (!ec) { return Status::OK(); } @@ -203,7 +203,7 @@ void ClientConnection::ProcessMessages() { std::placeholders::_1)); // Ignore byte_transferred } -void ClientConnection::ProcessMessageHeader(const std::error_code& ec) { +void ClientConnection::ProcessMessageHeader(const error_code& ec) { auto status = asio_to_arrow_status(ec); if (!status.ok()) { // If there was an error, disconnect the client. @@ -228,7 +228,7 @@ void ClientConnection::ProcessMessageHeader(const std::error_code& ec) { std::placeholders::_1)); } -void ClientConnection::ProcessMessageBody(const std::error_code& ec) { +void ClientConnection::ProcessMessageBody(const error_code& ec) { auto status = asio_to_arrow_status(ec); if (!status.ok()) { // If there was an error, disconnect the client. 
@@ -283,7 +283,7 @@ struct AsyncObjectNotificationWriteBuffer : public AsyncWriteBuffer { notification_msg.reset(message); size = message->size(); AsyncWriteBuffer::handler_ = - [](const asio::error_code& status) -> AsyncWriteCallbackCode { + [](const error_code& status) -> AsyncWriteCallbackCode { auto errno_ = status.value(); if (!errno_) { return AsyncWriteCallbackCode::OK; diff --git a/cpp/src/plasma/io/connection.h b/cpp/src/plasma/io/connection.h index 716bc6611e5..1bf232558d0 100644 --- a/cpp/src/plasma/io/connection.h +++ b/cpp/src/plasma/io/connection.h @@ -37,7 +37,7 @@ using arrow::Status; using PlasmaConnection = Connection; -Status asio_to_arrow_status(const std::error_code& ec); +Status asio_to_arrow_status(const error_code& ec); /// A generic type representing a client connection to a server. This typename /// can be used to write messages synchronously to the server. @@ -145,11 +145,11 @@ class ClientConnection : public ServerConnection { private: /// Process the message header from the client. /// \param ec The returned error code. - void ProcessMessageHeader(const std::error_code& ec); + void ProcessMessageHeader(const error_code& ec); /// Process the message body from the client. /// \param ec The returned error code. - void ProcessMessageBody(const std::error_code& ec); + void ProcessMessageBody(const error_code& ec); /// Process the message from the client. /// \param type The type of the message. diff --git a/cpp/src/plasma/store.cc b/cpp/src/plasma/store.cc index 71159c02d46..34964fec79a 100644 --- a/cpp/src/plasma/store.cc +++ b/cpp/src/plasma/store.cc @@ -116,7 +116,7 @@ struct GetRequest { } void AsyncWait(int64_t timeout_ms, - std::function on_timeout) { + std::function on_timeout) { // Set an expiry time relative to now. 
timer_.expires_from_now(std::chrono::milliseconds(timeout_ms)); timer_.async_wait(on_timeout); @@ -498,7 +498,7 @@ Status PlasmaStore::ProcessGetRequest(const std::shared_ptr& c } else if (timeout_ms != -1) { // Set a timer that will cause the get request to return to the client. Note // that a timeout of -1 is used to indicate that no timer should be set. - get_req->AsyncWait(timeout_ms, [this, get_req](const asio::error_code& ec) { + get_req->AsyncWait(timeout_ms, [this, get_req](const error_code& ec) { if (ec != asio::error::operation_aborted) { // Timer was not cancelled, take necessary action. ReturnFromGet(get_req); @@ -721,10 +721,10 @@ void PlasmaStore::SubscribeToUpdates(const std::shared_ptr& cl void PlasmaStore::DoAccept() { // TODO(suquark): Use shared_from_this() here ? acceptor_.async_accept(stream_, - [this](const asio::error_code& ec) { HandleAccept(ec); }); + [this](const error_code& ec) { HandleAccept(ec); }); } -void PlasmaStore::HandleAccept(const asio::error_code& error) { +void PlasmaStore::HandleAccept(const error_code& error) { if (!error) { io::MessageHandler message_handler = [this](std::shared_ptr client, int64_t message_type, int64_t length, diff --git a/cpp/src/plasma/store.h b/cpp/src/plasma/store.h index c992a8edf47..af84285a970 100644 --- a/cpp/src/plasma/store.h +++ b/cpp/src/plasma/store.h @@ -194,7 +194,7 @@ class PlasmaStore { /// Accept a client connection. void DoAccept(); /// Handle an accepted client connection. 
- void HandleAccept(const asio::error_code& error); + void HandleAccept(const error_code& error); Status ProcessClientMessage(const std::shared_ptr& client, int64_t message_type, int64_t message_size, From df0651075dcf0fc304c48dd2319efc061427349c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 14 Jul 2019 22:51:12 -0700 Subject: [PATCH 48/52] fix --- cpp/src/plasma/client.cc | 6 ++++-- cpp/src/plasma/io/basic_connection.cc | 11 ++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index 2a9c4e6c344..a6c02f7e506 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -918,7 +918,8 @@ Status PlasmaClient::Impl::Subscribe() { if (store_socket_name_.empty()) { ARROW_LOG(FATAL) << "Please connect to the store before subscribing messages."; } - auto stream = io::CreateLocalStream(io_context_, store_socket_name_); + PlasmaStream stream(io_context_); + RETURN_NOT_OK(io::CreateLocalStream(store_socket_name_, &stream)); auto conn = ServerConnection::Create(std::move(stream)); notification_conn_ = std::move(conn); // Tell the Plasma store about the subscription. 
@@ -966,7 +967,8 @@ Status PlasmaClient::Impl::GetNotification(ObjectID* object_id, int64_t* data_si Status PlasmaClient::Impl::Connect(const std::string& store_socket_name) { std::lock_guard guard(client_mutex_); store_socket_name_ = store_socket_name; - auto stream = io::CreateLocalStream(io_context_, store_socket_name_); + PlasmaStream stream(io_context_); + RETURN_NOT_OK(io::CreateLocalStream(io_context_, store_socket_name_, &stream)); auto conn = ServerConnection::Create(std::move(stream)); store_conn_ = std::move(conn); diff --git a/cpp/src/plasma/io/basic_connection.cc b/cpp/src/plasma/io/basic_connection.cc index 15ae7a05385..7325855ac52 100644 --- a/cpp/src/plasma/io/basic_connection.cc +++ b/cpp/src/plasma/io/basic_connection.cc @@ -44,15 +44,16 @@ error_code UnixDomainSocketConnect(asio::local::stream_protocol::socket& socket, return ec; } -PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& name) { +Status CreateLocalStream(const std::string& name, PlasmaStream* result) { // TODO(suquark): May be use "kNumConnectAttempts" and "kConnectTimeoutMs"? constexpr int num_retries = 50; constexpr int timeout_ms = 100; - ARROW_CHECK(!name.empty()); + if (name.empty()) { + return Status::Invalid("Cannot connect to empty socket name"); + } #ifndef _WIN32 - asio::basic_stream_socket socket(io_context); for (int i = 0; i < num_retries; i++) { - error_code ec = UnixDomainSocketConnect(socket, name); + error_code ec = UnixDomainSocketConnect(*result, name); if (!ec) { break; } @@ -63,7 +64,7 @@ PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& << ")"; } } - return socket; + return Status::OK(); #else // For windows: https://stackoverflow.com/questions/1236460/c-using-windows-named-pipes #error "Windows has not been supported." 
From fb0e41134071e7c2af5a472b74d57c5a5489cc9e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 14 Jul 2019 23:02:02 -0700 Subject: [PATCH 49/52] update --- cpp/src/plasma/client.cc | 4 ++-- cpp/src/plasma/io/basic_connection.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index a6c02f7e506..457728062f5 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -918,7 +918,7 @@ Status PlasmaClient::Impl::Subscribe() { if (store_socket_name_.empty()) { ARROW_LOG(FATAL) << "Please connect to the store before subscribing messages."; } - PlasmaStream stream(io_context_); + io::PlasmaStream stream(io_context_); RETURN_NOT_OK(io::CreateLocalStream(store_socket_name_, &stream)); auto conn = ServerConnection::Create(std::move(stream)); notification_conn_ = std::move(conn); @@ -967,7 +967,7 @@ Status PlasmaClient::Impl::GetNotification(ObjectID* object_id, int64_t* data_si Status PlasmaClient::Impl::Connect(const std::string& store_socket_name) { std::lock_guard guard(client_mutex_); store_socket_name_ = store_socket_name; - PlasmaStream stream(io_context_); + io::PlasmaStream stream(io_context_); RETURN_NOT_OK(io::CreateLocalStream(io_context_, store_socket_name_, &stream)); auto conn = ServerConnection::Create(std::move(stream)); store_conn_ = std::move(conn); diff --git a/cpp/src/plasma/io/basic_connection.h b/cpp/src/plasma/io/basic_connection.h index b1fc00abf51..998e0a8ef5c 100644 --- a/cpp/src/plasma/io/basic_connection.h +++ b/cpp/src/plasma/io/basic_connection.h @@ -54,7 +54,7 @@ using PlasmaAcceptor = asio::local::stream_protocol::acceptor; PlasmaAcceptor CreateLocalAcceptor(asio::io_context& io_context, const std::string& name); /// Create a local stream depends on the platform. 
-PlasmaStream CreateLocalStream(asio::io_context& io_context, const std::string& name); +Status CreateLocalStream(const std::string& name, PlasmaStream* result); /// A message that is queued for writing asynchronously. struct AsyncWriteBuffer { From 57bfd6471419fd7ac05d0da277889e0184a7f7ee Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 14 Jul 2019 23:06:04 -0700 Subject: [PATCH 50/52] fix --- cpp/src/plasma/io/basic_connection.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/plasma/io/basic_connection.h b/cpp/src/plasma/io/basic_connection.h index 998e0a8ef5c..b9a3e16d8cb 100644 --- a/cpp/src/plasma/io/basic_connection.h +++ b/cpp/src/plasma/io/basic_connection.h @@ -32,6 +32,8 @@ #include #include +#include "arrow/status.h" + namespace asio = boost::asio; using error_code = boost::system::error_code; From 4cad6028a9c044de7d6591e59868ddb56dfaae4a Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 14 Jul 2019 23:08:56 -0700 Subject: [PATCH 51/52] update --- cpp/src/plasma/io/basic_connection.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/plasma/io/basic_connection.h b/cpp/src/plasma/io/basic_connection.h index b9a3e16d8cb..0bda2eacfc2 100644 --- a/cpp/src/plasma/io/basic_connection.h +++ b/cpp/src/plasma/io/basic_connection.h @@ -37,6 +37,7 @@ namespace asio = boost::asio; using error_code = boost::system::error_code; +using arrow::Status; namespace plasma { namespace io { From 11871a2ca133db02186b5cd270af2997baecd383 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 14 Jul 2019 23:11:06 -0700 Subject: [PATCH 52/52] update --- cpp/src/plasma/client.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index 457728062f5..6015583d4dd 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -968,7 +968,7 @@ Status PlasmaClient::Impl::Connect(const std::string& store_socket_name) { std::lock_guard guard(client_mutex_); store_socket_name_ = 
store_socket_name; io::PlasmaStream stream(io_context_); - RETURN_NOT_OK(io::CreateLocalStream(io_context_, store_socket_name_, &stream)); + RETURN_NOT_OK(io::CreateLocalStream(store_socket_name_, &stream)); auto conn = ServerConnection::Create(std::move(stream)); store_conn_ = std::move(conn);