From 659c150583f71ab5f52bd7cb503084a815c17b75 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 11 Jul 2022 14:14:44 +0200 Subject: [PATCH 1/2] ARROW-17037: [C++] Split utf8.h to avoid exposing xsimd dependency to the outside --- cpp/src/arrow/array/array_binary.cc | 1 + .../compute/kernels/scalar_cast_string.cc | 6 +- .../compute/kernels/scalar_string_utf8.cc | 2 +- cpp/src/arrow/csv/converter.cc | 4 +- cpp/src/arrow/csv/reader.cc | 2 +- cpp/src/arrow/util/utf8.cc | 10 +- cpp/src/arrow/util/utf8.h | 526 +--------------- cpp/src/arrow/util/utf8_internal.h | 560 ++++++++++++++++++ cpp/src/arrow/util/utf8_util_benchmark.cc | 6 +- cpp/src/arrow/util/utf8_util_test.cc | 2 +- cpp/src/gandiva/gdv_string_function_stubs.cc | 2 +- 11 files changed, 584 insertions(+), 537 deletions(-) create mode 100644 cpp/src/arrow/util/utf8_internal.h diff --git a/cpp/src/arrow/array/array_binary.cc b/cpp/src/arrow/array/array_binary.cc index 9466b5a48f9..20d9ddcef5c 100644 --- a/cpp/src/arrow/array/array_binary.cc +++ b/cpp/src/arrow/array/array_binary.cc @@ -26,6 +26,7 @@ #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" +#include "arrow/util/utf8.h" namespace arrow { diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index dab91ac0346..6b21a532392 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -27,14 +27,14 @@ #include "arrow/util/formatting.h" #include "arrow/util/int_util.h" #include "arrow/util/optional.h" -#include "arrow/util/utf8.h" +#include "arrow/util/utf8_internal.h" #include "arrow/visit_data_inline.h" namespace arrow { using internal::StringFormatter; using util::InitializeUTF8; -using util::ValidateUTF8; +using util::ValidateUTF8Inline; namespace compute { namespace internal { @@ -197,7 +197,7 @@ struct Utf8Validator { Status VisitNull() { return Status::OK(); } Status VisitValue(util::string_view str) { - if (ARROW_PREDICT_FALSE(!ValidateUTF8(str))) { + if (ARROW_PREDICT_FALSE(!ValidateUTF8Inline(str))) { return Status::Invalid("Invalid UTF8 payload"); } return Status::OK(); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc index 02585ed34ac..4b3191c825d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc @@ -24,7 +24,7 @@ #endif #include "arrow/compute/kernels/scalar_string_internal.h" -#include "arrow/util/utf8.h" +#include "arrow/util/utf8_internal.h" namespace arrow { namespace compute { diff --git a/cpp/src/arrow/csv/converter.cc b/cpp/src/arrow/csv/converter.cc index b08502e5c66..c07eddffd43 100644 --- a/cpp/src/arrow/csv/converter.cc +++ b/cpp/src/arrow/csv/converter.cc @@ -37,7 +37,7 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" #include "arrow/util/trie.h" -#include "arrow/util/utf8.h" +#include "arrow/util/utf8_internal.h" #include "arrow/util/value_parsing.h" // IWYU pragma: keep namespace arrow { @@ -176,7 +176,7 @@ struct BinaryValueDecoder : public ValueDecoder { } Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { - if (CheckUTF8 && ARROW_PREDICT_FALSE(!util::ValidateUTF8(data, size))) { + if (CheckUTF8 && ARROW_PREDICT_FALSE(!util::ValidateUTF8Inline(data, size))) { return Status::Invalid("CSV conversion error to ", type_->ToString(), ": invalid UTF8 data"); } diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc index 376651ca408..ba754399b75 100644 --- a/cpp/src/arrow/csv/reader.cc +++ b/cpp/src/arrow/csv/reader.cc @@ -49,7 +49,7 @@ #include "arrow/util/optional.h" #include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" -#include "arrow/util/utf8.h" +#include "arrow/util/utf8_internal.h" #include "arrow/util/vector.h" namespace arrow { diff --git a/cpp/src/arrow/util/utf8.cc b/cpp/src/arrow/util/utf8.cc index 11394d2e64c..e589e1763e6 100644 --- a/cpp/src/arrow/util/utf8.cc +++ b/cpp/src/arrow/util/utf8.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/util/utf8.h" + #include #include #include @@ -23,7 +25,7 @@ #include "arrow/result.h" #include "arrow/util/logging.h" -#include "arrow/util/utf8.h" +#include "arrow/util/utf8_internal.h" #include "arrow/vendored/utfcpp/checked.h" // Can be defined by utfcpp @@ -90,6 +92,12 @@ void InitializeUTF8() { std::call_once(utf8_initialized, internal::InitializeLargeTable); } +bool ValidateUTF8(const uint8_t* data, int64_t size) { + return ValidateUTF8Inline(data, size); +} + +bool ValidateUTF8(const util::string_view& str) { return ValidateUTF8Inline(str); } + static const uint8_t kBOM[] = {0xEF, 0xBB, 0xBF}; Result SkipUTF8BOM(const uint8_t* data, int64_t size) { diff --git a/cpp/src/arrow/util/utf8.h b/cpp/src/arrow/util/utf8.h index 655108b5a36..eab207d2a02 100644 --- a/cpp/src/arrow/util/utf8.h +++ b/cpp/src/arrow/util/utf8.h @@ -17,21 +17,13 @@ #pragma once -#include #include #include -#include #include -#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) -#include -#endif - #include "arrow/type_fwd.h" #include "arrow/util/macros.h" -#include "arrow/util/simd.h" #include "arrow/util/string_view.h" -#include "arrow/util/ubsan.h" #include "arrow/util/visibility.h" namespace arrow { @@ -44,243 +36,12 @@ ARROW_EXPORT Result UTF8ToWideString(const std::string& source); // Similarly, convert a wstring to a UTF8 string. ARROW_EXPORT Result WideStringToUTF8(const std::wstring& source); -namespace internal { - -// Copyright (c) 2008-2010 Bjoern Hoehrmann -// See http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ for details. - -// A compact state table allowing UTF8 decoding using two dependent -// lookups per byte. The first lookup determines the character class -// and the second lookup reads the next state. -// In this table states are multiples of 12. -ARROW_EXPORT extern const uint8_t utf8_small_table[256 + 9 * 12]; - -// Success / reject states when looked up in the small table -static constexpr uint8_t kUTF8DecodeAccept = 0; -static constexpr uint8_t kUTF8DecodeReject = 12; - -// An expanded state table allowing transitions using a single lookup -// at the expense of a larger memory footprint (but on non-random data, -// not all the table will end up accessed and cached). -// In this table states are multiples of 256. -ARROW_EXPORT extern uint16_t utf8_large_table[9 * 256]; - -ARROW_EXPORT extern const uint8_t utf8_byte_size_table[16]; - -// Success / reject states when looked up in the large table -static constexpr uint16_t kUTF8ValidateAccept = 0; -static constexpr uint16_t kUTF8ValidateReject = 256; - -static inline uint8_t DecodeOneUTF8Byte(uint8_t byte, uint8_t state, uint32_t* codep) { - uint8_t type = utf8_small_table[byte]; - - *codep = (state != kUTF8DecodeAccept) ? (byte & 0x3fu) | (*codep << 6) - : (0xff >> type) & (byte); - - state = utf8_small_table[256 + state + type]; - return state; -} - -static inline uint16_t ValidateOneUTF8Byte(uint8_t byte, uint16_t state) { - return utf8_large_table[state + byte]; -} - -ARROW_EXPORT void CheckUTF8Initialized(); - -} // namespace internal - // This function needs to be called before doing UTF8 validation. ARROW_EXPORT void InitializeUTF8(); -static inline bool ValidateUTF8(const uint8_t* data, int64_t size) { - static constexpr uint64_t high_bits_64 = 0x8080808080808080ULL; - static constexpr uint32_t high_bits_32 = 0x80808080UL; - static constexpr uint16_t high_bits_16 = 0x8080U; - static constexpr uint8_t high_bits_8 = 0x80U; - -#ifndef NDEBUG - internal::CheckUTF8Initialized(); -#endif - - while (size >= 8) { - // XXX This is doing an unaligned access. Contemporary architectures - // (x86-64, AArch64, PPC64) support it natively and often have good - // performance nevertheless. - uint64_t mask64 = SafeLoadAs(data); - if (ARROW_PREDICT_TRUE((mask64 & high_bits_64) == 0)) { - // 8 bytes of pure ASCII, move forward - size -= 8; - data += 8; - continue; - } - // Non-ASCII run detected. - // We process at least 4 bytes, to avoid too many spurious 64-bit reads - // in case the non-ASCII bytes are at the end of the tested 64-bit word. - // We also only check for rejection at the end since that state is stable - // (once in reject state, we always remain in reject state). - // It is guaranteed that size >= 8 when arriving here, which allows - // us to avoid size checks. - uint16_t state = internal::kUTF8ValidateAccept; - // Byte 0 - state = internal::ValidateOneUTF8Byte(*data++, state); - --size; - // Byte 1 - state = internal::ValidateOneUTF8Byte(*data++, state); - --size; - // Byte 2 - state = internal::ValidateOneUTF8Byte(*data++, state); - --size; - // Byte 3 - state = internal::ValidateOneUTF8Byte(*data++, state); - --size; - // Byte 4 - state = internal::ValidateOneUTF8Byte(*data++, state); - --size; - if (state == internal::kUTF8ValidateAccept) { - continue; // Got full char, switch back to ASCII detection - } - // Byte 5 - state = internal::ValidateOneUTF8Byte(*data++, state); - --size; - if (state == internal::kUTF8ValidateAccept) { - continue; // Got full char, switch back to ASCII detection - } - // Byte 6 - state = internal::ValidateOneUTF8Byte(*data++, state); - --size; - if (state == internal::kUTF8ValidateAccept) { - continue; // Got full char, switch back to ASCII detection - } - // Byte 7 - state = internal::ValidateOneUTF8Byte(*data++, state); - --size; - if (state == internal::kUTF8ValidateAccept) { - continue; // Got full char, switch back to ASCII detection - } - // kUTF8ValidateAccept not reached along 4 transitions has to mean a rejection - assert(state == internal::kUTF8ValidateReject); - return false; - } - - // Check if string tail is full ASCII (common case, fast) - if (size >= 4) { - uint32_t tail_mask = SafeLoadAs(data + size - 4); - uint32_t head_mask = SafeLoadAs(data); - if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_32) == 0)) { - return true; - } - } else if (size >= 2) { - uint16_t tail_mask = SafeLoadAs(data + size - 2); - uint16_t head_mask = SafeLoadAs(data); - if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_16) == 0)) { - return true; - } - } else if (size == 1) { - if (ARROW_PREDICT_TRUE((*data & high_bits_8) == 0)) { - return true; - } - } else { - /* size == 0 */ - return true; - } - - // Fall back to UTF8 validation of tail string. - // Note the state table is designed so that, once in the reject state, - // we remain in that state until the end. So we needn't check for - // rejection at each char (we don't gain much by short-circuiting here). - uint16_t state = internal::kUTF8ValidateAccept; - switch (size) { - case 7: - state = internal::ValidateOneUTF8Byte(data[size - 7], state); - case 6: - state = internal::ValidateOneUTF8Byte(data[size - 6], state); - case 5: - state = internal::ValidateOneUTF8Byte(data[size - 5], state); - case 4: - state = internal::ValidateOneUTF8Byte(data[size - 4], state); - case 3: - state = internal::ValidateOneUTF8Byte(data[size - 3], state); - case 2: - state = internal::ValidateOneUTF8Byte(data[size - 2], state); - case 1: - state = internal::ValidateOneUTF8Byte(data[size - 1], state); - default: - break; - } - return ARROW_PREDICT_TRUE(state == internal::kUTF8ValidateAccept); -} - -static inline bool ValidateUTF8(const util::string_view& str) { - const uint8_t* data = reinterpret_cast(str.data()); - const size_t length = str.size(); - - return ValidateUTF8(data, length); -} - -static inline bool ValidateAsciiSw(const uint8_t* data, int64_t len) { - uint8_t orall = 0; - - if (len >= 8) { - uint64_t or8 = 0; - - do { - or8 |= SafeLoadAs(data); - data += 8; - len -= 8; - } while (len >= 8); - - orall = !(or8 & 0x8080808080808080ULL) - 1; - } - - while (len--) { - orall |= *data++; - } - - return orall < 0x80U; -} +ARROW_EXPORT bool ValidateUTF8(const uint8_t* data, int64_t size); -#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) -static inline bool ValidateAsciiSimd(const uint8_t* data, int64_t len) { - using simd_batch = xsimd::make_sized_batch_t; - - if (len >= 32) { - const simd_batch zero(static_cast(0)); - const uint8_t* data2 = data + 16; - simd_batch or1 = zero, or2 = zero; - - while (len >= 32) { - or1 |= simd_batch::load_unaligned(reinterpret_cast(data)); - or2 |= simd_batch::load_unaligned(reinterpret_cast(data2)); - data += 32; - data2 += 32; - len -= 32; - } - - // To test for upper bit in all bytes, test whether any of them is negative - or1 |= or2; - if (xsimd::any(or1 < zero)) { - return false; - } - } - - return ValidateAsciiSw(data, len); -} -#endif // ARROW_HAVE_NEON || ARROW_HAVE_SSE4_2 - -static inline bool ValidateAscii(const uint8_t* data, int64_t len) { -#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) - return ValidateAsciiSimd(data, len); -#else - return ValidateAsciiSw(data, len); -#endif -} - -static inline bool ValidateAscii(const util::string_view& str) { - const uint8_t* data = reinterpret_cast(str.data()); - const size_t length = str.size(); - - return ValidateAscii(data, length); -} +ARROW_EXPORT bool ValidateUTF8(const util::string_view& str); // Skip UTF8 byte order mark, if any. ARROW_EXPORT @@ -288,288 +49,5 @@ Result SkipUTF8BOM(const uint8_t* data, int64_t size); static constexpr uint32_t kMaxUnicodeCodepoint = 0x110000; -// size of a valid UTF8 can be determined by looking at leading 4 bits of BYTE1 -// utf8_byte_size_table[0..7] --> pure ascii chars --> 1B length -// utf8_byte_size_table[8..11] --> internal bytes --> 1B length -// utf8_byte_size_table[12,13] --> 2B long UTF8 chars -// utf8_byte_size_table[14] --> 3B long UTF8 chars -// utf8_byte_size_table[15] --> 4B long UTF8 chars -// NOTE: Results for invalid/ malformed utf-8 sequences are undefined. -// ex: \xFF... returns 4B -static inline uint8_t ValidUtf8CodepointByteSize(const uint8_t* codeunit) { - return internal::utf8_byte_size_table[*codeunit >> 4]; -} - -static inline bool Utf8IsContinuation(const uint8_t codeunit) { - return (codeunit & 0xC0) == 0x80; // upper two bits should be 10 -} - -static inline bool Utf8Is2ByteStart(const uint8_t codeunit) { - return (codeunit & 0xE0) == 0xC0; // upper three bits should be 110 -} - -static inline bool Utf8Is3ByteStart(const uint8_t codeunit) { - return (codeunit & 0xF0) == 0xE0; // upper four bits should be 1110 -} - -static inline bool Utf8Is4ByteStart(const uint8_t codeunit) { - return (codeunit & 0xF8) == 0xF0; // upper five bits should be 11110 -} - -/// Return the number of bytes required to UTF8-encode the given codepoint -static inline int32_t UTF8EncodedLength(uint32_t codepoint) { - if (codepoint < 0x80) { - return 1; - } else if (codepoint < 0x800) { - return 2; - } else if (codepoint < 0x10000) { - return 3; - } else { - return 4; - } -} - -static inline uint8_t* UTF8Encode(uint8_t* str, uint32_t codepoint) { - if (codepoint < 0x80) { - *str++ = codepoint; - } else if (codepoint < 0x800) { - *str++ = 0xC0 + (codepoint >> 6); - *str++ = 0x80 + (codepoint & 0x3F); - } else if (codepoint < 0x10000) { - *str++ = 0xE0 + (codepoint >> 12); - *str++ = 0x80 + ((codepoint >> 6) & 0x3F); - *str++ = 0x80 + (codepoint & 0x3F); - } else { - // Assume proper codepoints are always passed - assert(codepoint < kMaxUnicodeCodepoint); - *str++ = 0xF0 + (codepoint >> 18); - *str++ = 0x80 + ((codepoint >> 12) & 0x3F); - *str++ = 0x80 + ((codepoint >> 6) & 0x3F); - *str++ = 0x80 + (codepoint & 0x3F); - } - return str; -} - -static inline bool UTF8Decode(const uint8_t** data, uint32_t* codepoint) { - const uint8_t* str = *data; - if (*str < 0x80) { // ascii - *codepoint = *str++; - } else if (ARROW_PREDICT_FALSE(*str < 0xC0)) { // invalid non-ascii char - return false; - } else if (*str < 0xE0) { - uint8_t code_unit_1 = (*str++) & 0x1F; // take last 5 bits - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits - *codepoint = (code_unit_1 << 6) + code_unit_2; - } else if (*str < 0xF0) { - uint8_t code_unit_1 = (*str++) & 0x0F; // take last 4 bits - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_3 = (*str++) & 0x3F; // take last 6 bits - *codepoint = (code_unit_1 << 12) + (code_unit_2 << 6) + code_unit_3; - } else if (*str < 0xF8) { - uint8_t code_unit_1 = (*str++) & 0x07; // take last 3 bits - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_3 = (*str++) & 0x3F; // take last 6 bits - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_4 = (*str++) & 0x3F; // take last 6 bits - *codepoint = - (code_unit_1 << 18) + (code_unit_2 << 12) + (code_unit_3 << 6) + code_unit_4; - } else { // invalid non-ascii char - return false; - } - *data = str; - return true; -} - -static inline bool UTF8DecodeReverse(const uint8_t** data, uint32_t* codepoint) { - const uint8_t* str = *data; - if (*str < 0x80) { // ascii - *codepoint = *str--; - } else { - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_N = (*str--) & 0x3F; // take last 6 bits - if (Utf8Is2ByteStart(*str)) { - uint8_t code_unit_1 = (*str--) & 0x1F; // take last 5 bits - *codepoint = (code_unit_1 << 6) + code_unit_N; - } else { - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_Nmin1 = (*str--) & 0x3F; // take last 6 bits - if (Utf8Is3ByteStart(*str)) { - uint8_t code_unit_1 = (*str--) & 0x0F; // take last 4 bits - *codepoint = (code_unit_1 << 12) + (code_unit_Nmin1 << 6) + code_unit_N; - } else { - if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { - return false; - } - uint8_t code_unit_Nmin2 = (*str--) & 0x3F; // take last 6 bits - if (ARROW_PREDICT_TRUE(Utf8Is4ByteStart(*str))) { - uint8_t code_unit_1 = (*str--) & 0x07; // take last 3 bits - *codepoint = (code_unit_1 << 18) + (code_unit_Nmin2 << 12) + - (code_unit_Nmin1 << 6) + code_unit_N; - } else { - return false; - } - } - } - } - *data = str; - return true; -} - -template -static inline bool UTF8Transform(const uint8_t* first, const uint8_t* last, - uint8_t** destination, UnaryOperation&& unary_op) { - const uint8_t* i = first; - uint8_t* out = *destination; - while (i < last) { - uint32_t codepoint = 0; - if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { - return false; - } - out = UTF8Encode(out, unary_op(codepoint)); - } - *destination = out; - return true; -} - -template -static inline bool UTF8FindIf(const uint8_t* first, const uint8_t* last, - Predicate&& predicate, const uint8_t** position) { - const uint8_t* i = first; - while (i < last) { - uint32_t codepoint = 0; - const uint8_t* current = i; - if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { - return false; - } - if (predicate(codepoint)) { - *position = current; - return true; - } - } - *position = last; - return true; -} - -// Same semantics as std::find_if using reverse iterators with the return value -// having the same semantics as std::reverse_iterator<..>.base() -// A reverse iterator physically points to the next address, e.g.: -// &*reverse_iterator(i) == &*(i + 1) -template -static inline bool UTF8FindIfReverse(const uint8_t* first, const uint8_t* last, - Predicate&& predicate, const uint8_t** position) { - // converts to a normal point - const uint8_t* i = last - 1; - while (i >= first) { - uint32_t codepoint = 0; - const uint8_t* current = i; - if (ARROW_PREDICT_FALSE(!UTF8DecodeReverse(&i, &codepoint))) { - return false; - } - if (predicate(codepoint)) { - // converts normal pointer to 'reverse iterator semantics'. - *position = current + 1; - return true; - } - } - // similar to how an end pointer point to 1 beyond the last, reverse iterators point - // to the 'first' pointer to indicate out of range. - *position = first; - return true; -} - -static inline bool UTF8AdvanceCodepoints(const uint8_t* first, const uint8_t* last, - const uint8_t** destination, int64_t n) { - return UTF8FindIf( - first, last, - [&](uint32_t codepoint) { - bool done = n == 0; - n--; - return done; - }, - destination); -} - -static inline bool UTF8AdvanceCodepointsReverse(const uint8_t* first, const uint8_t* last, - const uint8_t** destination, int64_t n) { - return UTF8FindIfReverse( - first, last, - [&](uint32_t codepoint) { - bool done = n == 0; - n--; - return done; - }, - destination); -} - -template -static inline bool UTF8ForEach(const uint8_t* first, const uint8_t* last, - UnaryFunction&& f) { - const uint8_t* i = first; - while (i < last) { - uint32_t codepoint = 0; - if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { - return false; - } - f(codepoint); - } - return true; -} - -template -static inline bool UTF8ForEach(const std::string& s, UnaryFunction&& f) { - return UTF8ForEach(reinterpret_cast(s.data()), - reinterpret_cast(s.data() + s.length()), - std::forward(f)); -} - -template -static inline bool UTF8AllOf(const uint8_t* first, const uint8_t* last, bool* result, - UnaryPredicate&& predicate) { - const uint8_t* i = first; - while (i < last) { - uint32_t codepoint = 0; - if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { - return false; - } - - if (!predicate(codepoint)) { - *result = false; - return true; - } - } - *result = true; - return true; -} - -/// Count the number of codepoints in the given string (assuming it is valid UTF8). -static inline int64_t UTF8Length(const uint8_t* first, const uint8_t* last) { - int64_t length = 0; - while (first != last) { - length += ((*first++ & 0xc0) != 0x80); - } - return length; -} - } // namespace util } // namespace arrow diff --git a/cpp/src/arrow/util/utf8_internal.h b/cpp/src/arrow/util/utf8_internal.h new file mode 100644 index 00000000000..9d2954e9d1c --- /dev/null +++ b/cpp/src/arrow/util/utf8_internal.h @@ -0,0 +1,560 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) +#include +#endif + +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/simd.h" +#include "arrow/util/string_view.h" +#include "arrow/util/ubsan.h" +#include "arrow/util/utf8.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace util { + +namespace internal { + +// Copyright (c) 2008-2010 Bjoern Hoehrmann +// See http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ for details. + +// A compact state table allowing UTF8 decoding using two dependent +// lookups per byte. The first lookup determines the character class +// and the second lookup reads the next state. +// In this table states are multiples of 12. +ARROW_EXPORT extern const uint8_t utf8_small_table[256 + 9 * 12]; + +// Success / reject states when looked up in the small table +static constexpr uint8_t kUTF8DecodeAccept = 0; +static constexpr uint8_t kUTF8DecodeReject = 12; + +// An expanded state table allowing transitions using a single lookup +// at the expense of a larger memory footprint (but on non-random data, +// not all the table will end up accessed and cached). +// In this table states are multiples of 256. +ARROW_EXPORT extern uint16_t utf8_large_table[9 * 256]; + +ARROW_EXPORT extern const uint8_t utf8_byte_size_table[16]; + +// Success / reject states when looked up in the large table +static constexpr uint16_t kUTF8ValidateAccept = 0; +static constexpr uint16_t kUTF8ValidateReject = 256; + +static inline uint8_t DecodeOneUTF8Byte(uint8_t byte, uint8_t state, uint32_t* codep) { + uint8_t type = utf8_small_table[byte]; + + *codep = (state != kUTF8DecodeAccept) ? (byte & 0x3fu) | (*codep << 6) + : (0xff >> type) & (byte); + + state = utf8_small_table[256 + state + type]; + return state; +} + +static inline uint16_t ValidateOneUTF8Byte(uint8_t byte, uint16_t state) { + return utf8_large_table[state + byte]; +} + +ARROW_EXPORT void CheckUTF8Initialized(); + +} // namespace internal + +static inline bool ValidateUTF8Inline(const uint8_t* data, int64_t size) { + static constexpr uint64_t high_bits_64 = 0x8080808080808080ULL; + static constexpr uint32_t high_bits_32 = 0x80808080UL; + static constexpr uint16_t high_bits_16 = 0x8080U; + static constexpr uint8_t high_bits_8 = 0x80U; + +#ifndef NDEBUG + internal::CheckUTF8Initialized(); +#endif + + while (size >= 8) { + // XXX This is doing an unaligned access. Contemporary architectures + // (x86-64, AArch64, PPC64) support it natively and often have good + // performance nevertheless. + uint64_t mask64 = SafeLoadAs(data); + if (ARROW_PREDICT_TRUE((mask64 & high_bits_64) == 0)) { + // 8 bytes of pure ASCII, move forward + size -= 8; + data += 8; + continue; + } + // Non-ASCII run detected. + // We process at least 4 bytes, to avoid too many spurious 64-bit reads + // in case the non-ASCII bytes are at the end of the tested 64-bit word. + // We also only check for rejection at the end since that state is stable + // (once in reject state, we always remain in reject state). + // It is guaranteed that size >= 8 when arriving here, which allows + // us to avoid size checks. + uint16_t state = internal::kUTF8ValidateAccept; + // Byte 0 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + // Byte 1 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + // Byte 2 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + // Byte 3 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + // Byte 4 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + if (state == internal::kUTF8ValidateAccept) { + continue; // Got full char, switch back to ASCII detection + } + // Byte 5 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + if (state == internal::kUTF8ValidateAccept) { + continue; // Got full char, switch back to ASCII detection + } + // Byte 6 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + if (state == internal::kUTF8ValidateAccept) { + continue; // Got full char, switch back to ASCII detection + } + // Byte 7 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + if (state == internal::kUTF8ValidateAccept) { + continue; // Got full char, switch back to ASCII detection + } + // kUTF8ValidateAccept not reached along 4 transitions has to mean a rejection + assert(state == internal::kUTF8ValidateReject); + return false; + } + + // Check if string tail is full ASCII (common case, fast) + if (size >= 4) { + uint32_t tail_mask = SafeLoadAs(data + size - 4); + uint32_t head_mask = SafeLoadAs(data); + if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_32) == 0)) { + return true; + } + } else if (size >= 2) { + uint16_t tail_mask = SafeLoadAs(data + size - 2); + uint16_t head_mask = SafeLoadAs(data); + if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_16) == 0)) { + return true; + } + } else if (size == 1) { + if (ARROW_PREDICT_TRUE((*data & high_bits_8) == 0)) { + return true; + } + } else { + /* size == 0 */ + return true; + } + + // Fall back to UTF8 validation of tail string. + // Note the state table is designed so that, once in the reject state, + // we remain in that state until the end. So we needn't check for + // rejection at each char (we don't gain much by short-circuiting here). + uint16_t state = internal::kUTF8ValidateAccept; + switch (size) { + case 7: + state = internal::ValidateOneUTF8Byte(data[size - 7], state); + case 6: + state = internal::ValidateOneUTF8Byte(data[size - 6], state); + case 5: + state = internal::ValidateOneUTF8Byte(data[size - 5], state); + case 4: + state = internal::ValidateOneUTF8Byte(data[size - 4], state); + case 3: + state = internal::ValidateOneUTF8Byte(data[size - 3], state); + case 2: + state = internal::ValidateOneUTF8Byte(data[size - 2], state); + case 1: + state = internal::ValidateOneUTF8Byte(data[size - 1], state); + default: + break; + } + return ARROW_PREDICT_TRUE(state == internal::kUTF8ValidateAccept); +} + +static inline bool ValidateUTF8Inline(const util::string_view& str) { + const uint8_t* data = reinterpret_cast(str.data()); + const size_t length = str.size(); + + return ValidateUTF8Inline(data, length); +} + +static inline bool ValidateAsciiSw(const uint8_t* data, int64_t len) { + uint8_t orall = 0; + + if (len >= 8) { + uint64_t or8 = 0; + + do { + or8 |= SafeLoadAs(data); + data += 8; + len -= 8; + } while (len >= 8); + + orall = !(or8 & 0x8080808080808080ULL) - 1; + } + + while (len--) { + orall |= *data++; + } + + return orall < 0x80U; +} + +#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) +static inline bool ValidateAsciiSimd(const uint8_t* data, int64_t len) { + using simd_batch = xsimd::make_sized_batch_t; + + if (len >= 32) { + const simd_batch zero(static_cast(0)); + const uint8_t* data2 = data + 16; + simd_batch or1 = zero, or2 = zero; + + while (len >= 32) { + or1 |= simd_batch::load_unaligned(reinterpret_cast(data)); + or2 |= simd_batch::load_unaligned(reinterpret_cast(data2)); + data += 32; + data2 += 32; + len -= 32; + } + + // To test for upper bit in all bytes, test whether any of them is negative + or1 |= or2; + if (xsimd::any(or1 < zero)) { + return false; + } + } + + return ValidateAsciiSw(data, len); +} +#endif // ARROW_HAVE_NEON || ARROW_HAVE_SSE4_2 + +static inline bool ValidateAscii(const uint8_t* data, int64_t len) { +#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) + return ValidateAsciiSimd(data, len); +#else + return ValidateAsciiSw(data, len); +#endif +} + +static inline bool ValidateAscii(const util::string_view& str) { + const uint8_t* data = reinterpret_cast(str.data()); + const size_t length = str.size(); + + return ValidateAscii(data, length); +} + +// size of a valid UTF8 can be determined by looking at leading 4 bits of BYTE1 +// utf8_byte_size_table[0..7] --> pure ascii chars --> 1B length +// utf8_byte_size_table[8..11] --> internal bytes --> 1B length +// utf8_byte_size_table[12,13] --> 2B long UTF8 chars +// utf8_byte_size_table[14] --> 3B long UTF8 chars +// utf8_byte_size_table[15] --> 4B long UTF8 chars +// NOTE: Results for invalid/ malformed utf-8 sequences are undefined. +// ex: \xFF... returns 4B +static inline uint8_t ValidUtf8CodepointByteSize(const uint8_t* codeunit) { + return internal::utf8_byte_size_table[*codeunit >> 4]; +} + +static inline bool Utf8IsContinuation(const uint8_t codeunit) { + return (codeunit & 0xC0) == 0x80; // upper two bits should be 10 +} + +static inline bool Utf8Is2ByteStart(const uint8_t codeunit) { + return (codeunit & 0xE0) == 0xC0; // upper three bits should be 110 +} + +static inline bool Utf8Is3ByteStart(const uint8_t codeunit) { + return (codeunit & 0xF0) == 0xE0; // upper four bits should be 1110 +} + +static inline bool Utf8Is4ByteStart(const uint8_t codeunit) { + return (codeunit & 0xF8) == 0xF0; // upper five bits should be 11110 +} + +/// Return the number of bytes required to UTF8-encode the given codepoint +static inline int32_t UTF8EncodedLength(uint32_t codepoint) { + if (codepoint < 0x80) { + return 1; + } else if (codepoint < 0x800) { + return 2; + } else if (codepoint < 0x10000) { + return 3; + } else { + return 4; + } +} + +static inline uint8_t* UTF8Encode(uint8_t* str, uint32_t codepoint) { + if (codepoint < 0x80) { + *str++ = codepoint; + } else if (codepoint < 0x800) { + *str++ = 0xC0 + (codepoint >> 6); + *str++ = 0x80 + (codepoint & 0x3F); + } else if (codepoint < 0x10000) { + *str++ = 0xE0 + (codepoint >> 12); + *str++ = 0x80 + ((codepoint >> 6) & 0x3F); + *str++ = 0x80 + (codepoint & 0x3F); + } else { + // Assume proper codepoints are always passed + assert(codepoint < kMaxUnicodeCodepoint); + *str++ = 0xF0 + (codepoint >> 18); + *str++ = 0x80 + ((codepoint >> 12) & 0x3F); + *str++ = 0x80 + ((codepoint >> 6) & 0x3F); + *str++ = 0x80 + (codepoint & 0x3F); + } + return str; +} + +static inline bool UTF8Decode(const uint8_t** data, uint32_t* codepoint) { + const uint8_t* str = *data; + if (*str < 0x80) { // ascii + *codepoint = *str++; + } else if (ARROW_PREDICT_FALSE(*str < 0xC0)) { // invalid non-ascii char + return false; + } else if (*str < 0xE0) { + uint8_t code_unit_1 = (*str++) & 0x1F; // take last 5 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits + *codepoint = (code_unit_1 << 6) + code_unit_2; + } else if (*str < 0xF0) { + uint8_t code_unit_1 = (*str++) & 0x0F; // take last 4 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_3 = (*str++) & 0x3F; // take last 6 bits + *codepoint = (code_unit_1 << 12) + (code_unit_2 << 6) + code_unit_3; + } else if (*str < 0xF8) { + uint8_t code_unit_1 = (*str++) & 0x07; // take last 3 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_3 = (*str++) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_4 = (*str++) & 0x3F; // take last 6 bits + *codepoint = + (code_unit_1 << 18) + (code_unit_2 << 12) + (code_unit_3 << 6) + code_unit_4; + } else { // invalid non-ascii char + return false; + } + *data = str; + return true; +} + +static inline bool UTF8DecodeReverse(const uint8_t** data, uint32_t* codepoint) { + const uint8_t* str = *data; + if (*str < 0x80) { // ascii + *codepoint = *str--; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_N = (*str--) & 0x3F; // take last 6 bits + if (Utf8Is2ByteStart(*str)) { + uint8_t code_unit_1 = (*str--) & 0x1F; // take last 5 bits + *codepoint = (code_unit_1 << 6) + code_unit_N; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_Nmin1 = (*str--) & 0x3F; // take last 6 bits + if (Utf8Is3ByteStart(*str)) { + uint8_t code_unit_1 = (*str--) & 0x0F; // take last 4 bits + *codepoint = (code_unit_1 << 12) + (code_unit_Nmin1 << 6) + code_unit_N; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_Nmin2 = (*str--) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_TRUE(Utf8Is4ByteStart(*str))) { + uint8_t code_unit_1 = (*str--) & 0x07; // take last 3 bits + *codepoint = (code_unit_1 << 18) + (code_unit_Nmin2 << 12) + + (code_unit_Nmin1 << 6) + code_unit_N; + } else { + return false; + } + } + } + } + *data = str; + return true; +} + +template +static inline bool UTF8Transform(const uint8_t* first, const uint8_t* last, + uint8_t** destination, UnaryOperation&& unary_op) { + const uint8_t* i = first; + uint8_t* out = *destination; + while (i < last) { + uint32_t codepoint = 0; + if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { + return false; + } + out = UTF8Encode(out, unary_op(codepoint)); + } + *destination = out; + return true; +} + +template +static inline bool UTF8FindIf(const uint8_t* first, const uint8_t* last, + Predicate&& predicate, const uint8_t** position) { + const uint8_t* i = first; + while (i < last) { + uint32_t codepoint = 0; + const uint8_t* current = i; + if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { + return false; + } + if (predicate(codepoint)) { + *position = current; + return true; + } + } + *position = last; + return true; +} + +// Same semantics as std::find_if using reverse iterators with the return value +// having the same semantics as std::reverse_iterator<..>.base() +// A reverse iterator physically points to the next address, e.g.: +// &*reverse_iterator(i) == &*(i + 1) +template +static inline bool UTF8FindIfReverse(const uint8_t* first, const uint8_t* last, + Predicate&& predicate, const uint8_t** position) { + // converts to a normal point + const uint8_t* i = last - 1; + while (i >= first) { + uint32_t codepoint = 0; + const uint8_t* current = i; + if (ARROW_PREDICT_FALSE(!UTF8DecodeReverse(&i, &codepoint))) { + return false; + } + if (predicate(codepoint)) { + // converts normal pointer to 'reverse iterator semantics'. + *position = current + 1; + return true; + } + } + // similar to how an end pointer point to 1 beyond the last, reverse iterators point + // to the 'first' pointer to indicate out of range. + *position = first; + return true; +} + +static inline bool UTF8AdvanceCodepoints(const uint8_t* first, const uint8_t* last, + const uint8_t** destination, int64_t n) { + return UTF8FindIf( + first, last, + [&](uint32_t codepoint) { + bool done = n == 0; + n--; + return done; + }, + destination); +} + +static inline bool UTF8AdvanceCodepointsReverse(const uint8_t* first, const uint8_t* last, + const uint8_t** destination, int64_t n) { + return UTF8FindIfReverse( + first, last, + [&](uint32_t codepoint) { + bool done = n == 0; + n--; + return done; + }, + destination); +} + +template +static inline bool UTF8ForEach(const uint8_t* first, const uint8_t* last, + UnaryFunction&& f) { + const uint8_t* i = first; + while (i < last) { + uint32_t codepoint = 0; + if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { + return false; + } + f(codepoint); + } + return true; +} + +template +static inline bool UTF8ForEach(const std::string& s, UnaryFunction&& f) { + return UTF8ForEach(reinterpret_cast(s.data()), + reinterpret_cast(s.data() + s.length()), + std::forward(f)); +} + +template +static inline bool UTF8AllOf(const uint8_t* first, const uint8_t* last, bool* result, + UnaryPredicate&& predicate) { + const uint8_t* i = first; + while (i < last) { + uint32_t codepoint = 0; + if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { + return false; + } + + if (!predicate(codepoint)) { + *result = false; + return true; + } + } + *result = true; + return true; +} + +/// Count the number of codepoints in the given string (assuming it is valid UTF8). +static inline int64_t UTF8Length(const uint8_t* first, const uint8_t* last) { + int64_t length = 0; + while (first != last) { + length += ((*first++ & 0xc0) != 0x80); + } + return length; +} + +} // namespace util +} // namespace arrow diff --git a/cpp/src/arrow/util/utf8_util_benchmark.cc b/cpp/src/arrow/util/utf8_util_benchmark.cc index 2cbaa181d0f..69bb1285a34 100644 --- a/cpp/src/arrow/util/utf8_util_benchmark.cc +++ b/cpp/src/arrow/util/utf8_util_benchmark.cc @@ -22,11 +22,11 @@ #include #include "arrow/testing/gtest_util.h" -#include "arrow/util/utf8.h" +#include "arrow/util/utf8_internal.h" -// Do not benchmark inlined functions directly inside the benchmark loop +// Do not benchmark potentially inlined functions directly inside the benchmark loop static ARROW_NOINLINE bool ValidateUTF8NoInline(const uint8_t* data, int64_t size) { - return ::arrow::util::ValidateUTF8(data, size); + return ::arrow::util::ValidateUTF8Inline(data, size); } static ARROW_NOINLINE bool ValidateAsciiNoInline(const uint8_t* data, int64_t size) { diff --git a/cpp/src/arrow/util/utf8_util_test.cc b/cpp/src/arrow/util/utf8_util_test.cc index 878d924e4b6..2af5ac954b6 100644 --- a/cpp/src/arrow/util/utf8_util_test.cc +++ b/cpp/src/arrow/util/utf8_util_test.cc @@ -24,7 +24,7 @@ #include "arrow/testing/gtest_util.h" #include "arrow/util/string.h" -#include "arrow/util/utf8.h" +#include "arrow/util/utf8_internal.h" namespace arrow { namespace util { diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc index eb3831ccb49..f4bc5f84626 100644 --- a/cpp/src/gandiva/gdv_string_function_stubs.cc +++ b/cpp/src/gandiva/gdv_string_function_stubs.cc @@ -26,7 +26,7 @@ #include "arrow/util/double_conversion.h" #include "arrow/util/string_view.h" -#include "arrow/util/utf8.h" +#include "arrow/util/utf8_internal.h" #include "arrow/util/value_parsing.h" #include "gandiva/engine.h" From bfa35aeee9574914f38a95b714960b88aa99d4bc Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 12 Jul 2022 09:36:33 +0200 Subject: [PATCH 2/2] Remove superfluous inclusion --- cpp/src/arrow/array/array_binary.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/arrow/array/array_binary.cc b/cpp/src/arrow/array/array_binary.cc index 20d9ddcef5c..9466b5a48f9 100644 --- a/cpp/src/arrow/array/array_binary.cc +++ b/cpp/src/arrow/array/array_binary.cc @@ -26,7 +26,6 @@ #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" -#include "arrow/util/utf8.h" namespace arrow {