diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index bee14ae4ce3..1d832cc25a2 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -119,6 +119,22 @@ function(ADD_ARROW_BENCHMARK REL_TEST_NAME) ${ARG_UNPARSED_ARGUMENTS}) endfunction() +macro(append_avx2_src SRC) + if(ARROW_HAVE_RUNTIME_AVX2) + list(APPEND ARROW_SRCS ${SRC}) + set_source_files_properties(${SRC} PROPERTIES SKIP_PRECOMPILE_HEADERS ON) + set_source_files_properties(${SRC} PROPERTIES COMPILE_FLAGS ${ARROW_AVX2_FLAG}) + endif() +endmacro() + +macro(append_avx512_src SRC) + if(ARROW_HAVE_RUNTIME_AVX512) + list(APPEND ARROW_SRCS ${SRC}) + set_source_files_properties(${SRC} PROPERTIES SKIP_PRECOMPILE_HEADERS ON) + set_source_files_properties(${SRC} PROPERTIES COMPILE_FLAGS ${ARROW_AVX512_FLAG}) + endif() +endmacro() + set(ARROW_SRCS array/array_base.cc array/array_binary.cc @@ -215,19 +231,9 @@ set(ARROW_SRCS vendored/double-conversion/diy-fp.cc vendored/double-conversion/strtod.cc) -if(ARROW_HAVE_RUNTIME_AVX2) - list(APPEND ARROW_SRCS util/bpacking_avx2.cc) - set_source_files_properties(util/bpacking_avx2.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON) - set_source_files_properties(util/bpacking_avx2.cc PROPERTIES COMPILE_FLAGS - ${ARROW_AVX2_FLAG}) -endif() -if(ARROW_HAVE_RUNTIME_AVX512) - list(APPEND ARROW_SRCS util/bpacking_avx512.cc) - set_source_files_properties(util/bpacking_avx512.cc PROPERTIES SKIP_PRECOMPILE_HEADERS - ON) - set_source_files_properties(util/bpacking_avx512.cc PROPERTIES COMPILE_FLAGS - ${ARROW_AVX512_FLAG}) -endif() +append_avx2_src(util/bpacking_avx2.cc) +append_avx512_src(util/bpacking_avx512.cc) + if(ARROW_HAVE_NEON) list(APPEND ARROW_SRCS util/bpacking_neon.cc) endif() @@ -397,23 +403,21 @@ if(ARROW_COMPUTE) compute/kernels/vector_hash.cc compute/kernels/vector_nested.cc compute/kernels/vector_selection.cc - compute/kernels/vector_sort.cc) - - if(ARROW_HAVE_RUNTIME_AVX2) - list(APPEND ARROW_SRCS compute/kernels/aggregate_basic_avx2.cc) - set_source_files_properties(compute/kernels/aggregate_basic_avx2.cc PROPERTIES - SKIP_PRECOMPILE_HEADERS ON) - set_source_files_properties(compute/kernels/aggregate_basic_avx2.cc PROPERTIES - COMPILE_FLAGS ${ARROW_AVX2_FLAG}) - endif() - - if(ARROW_HAVE_RUNTIME_AVX512) - list(APPEND ARROW_SRCS compute/kernels/aggregate_basic_avx512.cc) - set_source_files_properties(compute/kernels/aggregate_basic_avx512.cc PROPERTIES - SKIP_PRECOMPILE_HEADERS ON) - set_source_files_properties(compute/kernels/aggregate_basic_avx512.cc PROPERTIES - COMPILE_FLAGS ${ARROW_AVX512_FLAG}) - endif() + compute/kernels/vector_sort.cc + compute/exec/key_hash.cc + compute/exec/key_map.cc + compute/exec/key_compare.cc + compute/exec/key_encode.cc + compute/exec/util.cc) + + append_avx2_src(compute/kernels/aggregate_basic_avx2.cc) + append_avx512_src(compute/kernels/aggregate_basic_avx512.cc) + + append_avx2_src(compute/exec/key_hash_avx2.cc) + append_avx2_src(compute/exec/key_map_avx2.cc) + append_avx2_src(compute/exec/key_compare_avx2.cc) + append_avx2_src(compute/exec/key_encode_avx2.cc) + append_avx2_src(compute/exec/util_avx2.cc) list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc) endif() diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_1.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_1.jpg new file mode 100644 index 00000000000..814ad8a69f6 Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_1.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_10.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_10.jpg new file mode 100644 index 00000000000..7a75c96dfc5 Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_10.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_11.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_11.jpg new file mode 100644 index 00000000000..59bcc167ed2 Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_11.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_2.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_2.jpg new file mode 100644 index 00000000000..4484c57a81d Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_2.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_3.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_3.jpg new file mode 100644 index 00000000000..afd33aba2e0 Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_3.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_4.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_4.jpg new file mode 100644 index 00000000000..f026aebe9a2 Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_4.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_5.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_5.jpg new file mode 100644 index 00000000000..8e1981b6571 Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_5.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_6.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_6.jpg new file mode 100644 index 00000000000..e976a461459 Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_6.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_7.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_7.jpg new file mode 100644 index 00000000000..7552d5af6af Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_7.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_8.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_8.jpg new file mode 100644 index 00000000000..242f1305328 Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_8.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/img/key_map_9.jpg b/cpp/src/arrow/compute/exec/doc/img/key_map_9.jpg new file mode 100644 index 00000000000..4c064595c9a Binary files /dev/null and b/cpp/src/arrow/compute/exec/doc/img/key_map_9.jpg differ diff --git a/cpp/src/arrow/compute/exec/doc/key_map.md b/cpp/src/arrow/compute/exec/doc/key_map.md new file mode 100644 index 00000000000..fdedc88c4d4 --- /dev/null +++ b/cpp/src/arrow/compute/exec/doc/key_map.md @@ -0,0 +1,223 @@ + + +# Swiss Table + +A specialized hash table implementation used to dynamically map combinations of key field values to a dense set of integer ids. Ids can later be used in place of keys to identify groups of rows with equal keys. + +## Introduction + +Hash group-by in Arrow uses a variant of a hash table based on a data structure called Swiss table. Swiss table uses linear probing. There is an array of slots and the information related to inserted keys is stored in these slots. A hash function determines the slot where the search for a matching key will start during hash table lookup. Then the slots are visited sequentially, wrapping around the end of an array, until either a match or an empty slot is found, the latter case meaning that there is no match. Swiss table organizes the slots in blocks of 8 and has a design that enables data level parallelism at the block level. More precisely, it allows for visiting all slots within a block at once during lookups, by simply using 64-bit arithmetic. SIMD instructions can further enhance this data level parallelism allowing to process multiple blocks related to multiple input keys together using SIMD vectors of 64-bit elements. Occupied slots within a block are always clustered together. The name Swiss table comes from likening resulting sequences of empty slots to holes in a one dimensional cheese. + +## Interface + +Hash table used in query processing for implementing join and group-by operators does not need to provide all of the operations that a general purpose hash table would. Simplified requirements can help achieve a simpler and more efficient design. For instance we do not need to be able to remove previously inserted keys. It’s an append-only data structure: new keys can be added but old keys are never erased. Also, only a single copy of each key can be inserted - it is like `std::map` in that sense and not `std::multimap`. + +Our Swiss table is fully vectorized. That means that all methods work on vectors of input keys processing them in batches. Specialized SIMD implementations of processing functions are almost always provided for performance critical operations. All callback interfaces used from the core hash table code are also designed to work on batches of inputs instead of individual keys. The batch size can be almost arbitrary and is selected by the client of the hash table. Batch size should be the smallest number of input items, big enough so that the benefits of vectorization and SIMD can be fully experienced. Keeping it small means less memory used for temporary arrays storing intermediate results of computation (vector equivalent of some temporary variables kept on the stack). That in turn means smaller space in CPU caches, which also means less impact on other memory access intensive operations. We pick 1024 as the default size of the batch. We will call it a **mini-batch** to distinguish it from potentially other forms of batches used at higher levels in the code, e.g. when scheduling work for worker threads or relational operators inside an analytic query. + +The main functionality provided by Swiss table is mapping of arbitrarily complex keys to unique integer ids. Let us call it **lookup-or-insert**. Given a sequence of key values, return a corresponding sequence of integer ids, such that all keys that are equal receive the same id and for K distinct keys the integer ids will be assigned from the set of numbers 0 to (K-1). If we find a matching key in a hash table for a given input, we return the **key id** assigned when the key was first inserted into a hash table. If we fail to find an already inserted match, we assign the first unused integer as a key id and add a new entry to a hash table. Due to vectorized processing, which may result in out-of-order processing of individual inputs, it is not guaranteed that if there are two new key values in the same input batch and one of them appears earlier in the input sequence, then it will receive a smaller key id. Additional mapping functionality can be built on top of basic mapping to integer key id, for instance if we want to assign and perhaps keep updating some values to all unique keys, we can keep these values in a resizable vector indexed by obtained key id. + +The implementation of Swiss table does not need to have any information related to the domain of the keys. It does not use their logical data type or information about their physical representation and does not even use pointers to keys. All access to keys is delegated to a separate class or classes that provide callback functions for three operations: +- computing hashes of keys; +- checking equality for given pairs of keys; +- appending a given sequence of keys to a stack maintained outside of Swiss table object, so that they can be referenced later on by key ids (key ids will be equal to their positions in the stack). + + +When passing arguments to callback functions the keys are referenced using integer ids. For the left side - that is the keys present in the input mini-batch - ordinal positions within that mini-batch are used. For the right side - that is the keys inserted into the hash table - these are identified by key ids assigned to them and stored inside Swiss table when they were first encountered and processed. + +Diagram with logical view of information passing in callbacks: + +![alt text](img/key_map_1.jpg) + +Hash table values for inserted keys are also stored inside Swiss table. Because of that, hash table logic does not need to ever re-evaluate the hash, and there is actually no need for a hash function callback. It is enough that the caller provides hash values for all entries in the batch when calling lookup-or-insert. + +## Basic architecture and organization of data +The hash table is an array of **slots**. Slots are grouped in groups of 8 called **blocks**. The number of blocks is a power of 2. The empty hash table starts with a single block, with all slots empty. Then, as the keys are getting inserted and the amount of empty slots is shrinking, at some point resizing of the hash table is triggered. The data stored in slots is moved to a new hash table that has the double of the number of blocks. + +The diagram below shows the basic organization of data in our implementation of Swiss table: + +![alt text](img/key_map_2.jpg) + +N is the log of the number of blocks, 2n+3 is the number of slots and also the maximum number of inserted keys and hence (N + 3) is the number of bits required to store a key id. We will refer to N as the **size of the hash table**. + +Index of a block within an array will be called **block id**, and similarly index of a slot will be **slot id**. Sometimes we will focus on a single block and refer to slots that belong to it by using a **local slot id**, which is an index from 0 to 7. + +Every slot can either be **empty** or store data related to a single inserted key. There are three pieces of information stored inside a slot: +- status byte, +- key id, +- key hash. + +Status byte, as the name suggests, stores 8 bits. The highest bit indicates if the slot is empty (the highest bit is set) or corresponds to one of inserted keys (the highest bit is zero). The remaining 7 bits contain 7 bits of key hash that we call a **stamp**. The stamp is used to eliminate some false positives when searching for a matching key for a given input. Slot also stores **key id**, which is a non-negative integer smaller than the number of inserted keys, that is used as a reference to the actual inserted key. The last piece of information related to an inserted key is its **hash** value. We store hashes for all keys, so that they never need to be re-computed. That greatly simplifies some operations, like resizing of a hash table, that may not even need to look at the keys at all. For an empty slot, the status byte is 0x80, key id is zero and the hash is not used and can be set to any number. + +A single block contains 8 slots and can be viewed as a micro-stack of up to 8 inserted keys. When the first key is inserted into an empty block, it will occupy a slot with local id 0. The second inserted key will go into slot number 1 and so on. We use N highest bits of hash to get an index of a **start block**, when searching for a match or an empty slot to insert a previously not seen key when that is the case. If the start block contains any empty slots, then the search for either a match or place to insert a key will end at that block. We will call such a block an **open block**. A block that is not open is a full block. In the case of full block, the input key related search may continue in the next block module the number of blocks. If the key is not inserted into its start block, we will refer to it as an **overflow** entry, other entries being **non-overflow**. Overflow entries are slower to process, since they require visiting more than one block, so we want to keep their percentage low. This is done by choosing the right **load factor** (percentage of occupied slots in the hash table) at which the hash table gets resized and the number of blocks gets doubled. By tuning this value we can control the probability of encountering an overflow entry. + +The most interesting part of each block is the set of status bytes of its slots, which is simply a single 64-bit word. The implementation of efficient searches across these bytes during lookups require using either leading zero count or trailing zero count intrinsic. Since there are cases when only the first one is available, in order to take advantage of it, we order the bytes in the 64-bit status word so that the first slot within a block uses the highest byte and the last one uses the lowest byte (slots are in reversed bytes order). The diagram below shows how the information about slots is stored within a 64-bit status word: + +![alt text](img/key_map_3.jpg) + +Each status byte has a 7-bit fragment of hash value - a **stamp** - and an empty slot bit. Empty slots have status byte equal to 0x80 - the highest bit is set to 1 to indicate an empty slot and the lowest bits, which are used by a stamp, are set to zero. + +The diagram below shows which bits of hash value are used by hash table: + +![alt text](img/key_map_4.jpg) + +If a hash table has 2N blocks, then we use N highest bits of a hash to select a start block when searching for a match. The next 7 bits are used as a stamp. Using the highest bits to pick a start block means that a range of hash values can be easily mapped to a range of block ids of start blocks for hashes in that range. This is useful when resizing a hash table or merging two hash tables together. + +### Interleaving status bytes and key ids + +Status bytes and key ids for all slots are stored in a single array of bytes. They are first grouped by 8 into blocks, then each block of status bytes is interleaved with a corresponding block of key ids. Finally key ids are represented using the smallest possible number of bits and bit-packed (bits representing each next key id start right after the last bit of the previous key id). Note that regardless of the chosen number of bits, a block of bit-packed key ids (that is 8 of them) will start and end on the byte boundary. + +The diagram below shows the organization of bytes and bits of a single block in interleaved array: +![alt text](img/key_map_5.jpg) + +From the size of the hash table we can derive the number K of bits needed in the worst case to encode any key id. K is equal to the number of bits needed to represent slot id (number of keys is not greater than the number of slots and any key id is strictly less than the number of keys), which for a hash table of size N (N blocks) equals (N+3). To simplify bit packing and unpacking and avoid handling of special cases, we will round up K to full bytes for K > 24 bits. + +Status bytes are stored in a single 64-bit word in reverse byte order (the last byte corresponds to the slot with local id 0). On the other hand key ids are stored in the normal order (the order of slot ids). + +Since both status byte and key id for a given slot are stored in the same array close to each other, we can expect that most of the lookups will read only one CPU cache-line from memory inside Swiss table code (then at least another one outside Swiss table to access the bytes of the key for the purpose of comparison). Even if we hit an overflow entry, it is still likely to reside on the same cache-line as the start block data. Hash values, which are stored separately from status byte and key id, are only used when resizing and do not impact the lookups outside these events. + +> Improvement to consider: +> In addition to the Swiss table data, we need to store an array of inserted keys, one for each key id. If keys are of fixed length, then the address of the bytes of the key can be calculated by multiplying key id by the common length of the key. If keys are of varying length, then there will be an additional array with an offset of each key within the array of concatenated bytes of keys. That means that any key comparison during lookup will involve 3 arrays: one to get key id, one to get key offset and final one with bytes of the key. This could be reduced to 2 array lookups if we stored key offset instead of key id interleaved with slot status bytes. Offset indexed by key id and stored in its own array becomes offset indexed by slot id and stored interleaved with slot status bytes. At the same time key id indexed by slot id and interleaved with slot status bytes before becomes key id referenced using offset and stored with key bytes. There may be a slight increase in the total size of memory needed by the hash table, equal to the difference in the number of bits used to store offset and those used to store key id, multiplied by the number of slots, but that should be a small fraction of the total size. + +### 32-bit hash vs 64-bit hash + +Currently we use 32-bit hash values in Swiss table code and 32-bit integers as key ids. For the robust implementation, sooner or later we will need to support 64-bit hash and 64-bit key ids. When we use 32-bit hash, it means that we run out of hash bits when hash table size N is greater than 25 (25 bits of hash needed to select a block and 7 bits needed to generate a stamp byte reach 32 total bits). When the number of inserted keys exceeds the maximal number of keys stored in a hash table of size 25 (which is at least 224), the chance of false positives during lookups will start quickly growing. 32-bit hash should not be used with more than about 16 million inserted keys. + +### Low memory footprint and low chance of hash collisions + +Swiss table is a good choice of a hash table for modern hardware, because it combines lookups that can take advantage of special CPU instructions with space efficiency and low chance of hash collisions. + +Space efficiency is important for performance, because the cost of random array accesses, often dominating the lookup cost for larger hash tables, increases with the size of the arrays. This happens due to limited space of CPU caches. Let us look at what is the amortized additional storage cost for a key in a hash table apart from the essential cost of storing data of all those keys. Furthermore, we can skip the storage of hash values, since these are only used during infrequent hash table resize operations (should not have a big impact on CPU cache usage in normal cases). + +Half full hash table of size N will use 2 status bytes per inserted key (because for every filled slot there is one empty slot) and 2\*(N+3) bits for key id (again, one for the occupied slot and one for the empty). For N = 16 for instance this is slightly under 7 bytes per inserted key. + +Swiss table also has a low probability of false positives leading to wasted key comparisons. Here is some rationale behind why this should be the case. Hash table of size N can contain up to 2N+3 keys. Search for a match involves (N + 7) hash bits: N to select a start block and 7 to use as a stamp. There are always at least 16 times more combinations of used hash bits than there are keys in the hash table (32 times more if the hash table is half full). These numbers mean that the probability of false positives resulting from a search for a matching slot should be low. That corresponds to an expected number of comparisons per lookup being close to 1 for keys already present and 0 for new keys. + +## Lookup + +Lookup-or-insert operation, given a hash of a key, finds a list of candidate slots with corresponding keys that are likely to be equal to the input key. The list may be empty, which means that the key does not exist yet in the hash table. If it is not empty, then the callback function for key comparison is called for each next candidate to verify that there is indeed a match. False positives get rejected and we end up either finding an actual match or an empty slot, which means that the key is new to the hash table. New keys get assigned next available integers as key ids, and are appended to the set of keys stored in the hash table. As a result of inserting new keys to the hash table, the density of occupied slots may reach an upper limit, at which point the hash table will be resized and will afterwards have twice as many slots. That is in summary lookup-or-insert functionality, but the actual implementation is a bit more involved, because of vectorization of the processing and various optimizations for common cases. + +### Search within a single block + +There are three possible cases that can occur when searching for a match for a given key (that is, for a given stamp of a key) within a single block, illustrated below. + + 1. There is a matching stamp in the block of status bytes: + +![alt text](img/key_map_6.jpg) + + 2. There is no matching stamp in the block, but there is an empty slot in the block: + +![alt text](img/key_map_7.jpg) + + 3. There is no matching stamp in the block and the block is full (there are no empty slots left): + +![alt text](img/key_map_8.jpg) + +64-bit arithmetic can be used to search for a matching slot within the entire single block at once, without iterating over all slots in it. Following is an example of a sequence of steps to find the first status byte for a given stamp, returning the first empty slot on miss if the block is not full or 8 (one past maximum local slot id) otherwise. + +Following is a sketch of the possible steps to execute when searching for the matching stamp in a single block. + +*Example will use input stamp 0x5E and a 64-bit status bytes word with one empty slot: +0x 4B17 5E3A 5E2B 1180*. + +1. [1 instruction] Replicate stamp to all bytes by multiplying it by 0x 0101 0101 0101 0101. + + *We obtain: 0x 5E5E 5E5E 5E5E 5E5E.* + +2. [1 instruction] XOR replicated stamp with status bytes word. Bytes corresponding to a matching stamp will be 0, bytes corresponding to empty slots will have a value between 128 and 255, bytes corresponding to non-matching non-empty slots will have a value between 1 and 127. + + *We obtain: 0x 1549 0064 0075 4FDE.* + +3. [2 instructions] In the next step we want to have information about a match in the highest bit of each byte. We can ignore here empty slot bytes, because they will be taken care of at a later step. Set the highest bit in each byte (OR with 0x 8080 8080 8080 8080) and then subtract 1 from each byte (subtract 0x 0101 0101 0101 0101 from 64-bit word). Now if a byte corresponds to a non-empty slot then the highest bit 0 indicates a match and 1 indicates a miss. + + *We obtain: 0x 95C9 80E4 80F5 CFDE, + then 0x 94C8 7FE3 7FF4 CEDD.* + +4. [3 instructions] In the next step we want to obtain in each byte one of two values: 0x80 if it is either an empty slot or a match, 0x00 otherwise. We do it in three steps: NOT the result of the previous step to change the meaning of the highest bit; OR with the original status word to set highest bit in a byte to 1 for empty slots; mask out everything other than the highest bits in all bytes (AND with 0x 8080 8080 8080 8080). + + *We obtain: 6B37 801C 800B 3122, + then 6B37 DE3E DE2B 31A2, + finally 0x0000 8000 8000 0080.* + +5. [2 instructions] Finally, use leading zero bits count and divide it by 8 to find an index of the last byte that corresponds either to a match or an empty slot. If the leading zero count intrinsic returns 64 for a 64-bit input zero, then after dividing by 8 we will also get the desired answer in case of a full block without any matches. + + *We obtain: 16, + then 2 (index of the first slot within the block that matches the stamp).* + +If SIMD instructions with 64-bit lanes are available, multiple single block searches for different keys can be executed together. For instance AVX2 instruction set allows to process quadruplets of 64-bit values in a single instruction, four searches at once. + +### Complete search potentially across multiple blocks + +Full implementation of a search for a matching key may involve visiting multiple blocks beginning with the start block selected based on the hash of the key. We move to the next block modulo the number of blocks, whenever we do not find a match in the current block and the current block is full. The search may also involve visiting one or more slots in each block. Visiting in this case means calling a comparison callback to verify the match whenever a slot with a matching stamp is encountered. Eventually the search stops when either: +- the matching key is found in one of the slots matching the stamp, or + +- an empty slot is reached. This is illustrated in the diagram below: +![alt text](img/key_map_9.jpg) + + +### Optimistic processing with two passes + +Hash table lookups may have high cost in the pessimistic case, when we encounter cases of hash collisions and full blocks that lead to visiting further blocks. In the majority of cases we can expect an optimistic situation - the start block is not full, so we will only visit this one block, and all stamps in the block are different, so we will need at most one comparison to find a match. We can expect about 90% of the key lookups for an existing key to go through the optimistic path of processing. For that reason it pays off to optimize especially for this 90% of inputs. + +Lookups in Swiss table are split into two passes over an input batch of keys. The **first pass: fast-path lookup** , is a highly optimized, vectorized, SIMD-friendly, branch-free code that fully handles optimistic cases. The **second pass: slow-path lookup** , is normally executed only for the selection of inputs that have not been finished in the first pass, although it can also be called directly on all of the inputs, skipping fast-path lookup. It handles all special cases and inserts but in order to be robust it is not as efficient as fast-path. Slow-path lookup does not need to repeat the work done in fast-path lookup - it can use the state reached at the end of fast-path lookup as a starting point. + +Fast-path lookup implements search only for the first stamp match and only within the start block. It only makes sense when we already have at least one key inserted into the hash table, since it does not handle inserts. It takes a vector of key hashes as an input and based on it outputs three pieces of information for each key: + +- Key id corresponding to the slot in which a matching stamp was found. Any valid key id if a matching stamp was not found. +- A flag indicating if a match was found or not. +- Slot id of a slot from which slow-path should pick up the search if the first match was either not found or it turns out to be false positive after evaluating key comparison. + +> Improvement to consider: +> precomputing 1st pass lookup results. +> +> If the hash table is small, the number of inserted keys is small, we could further simplify and speed-up the first pass by storing in a lookup table pre-computed results for all combinations of hash bits. Let us consider the case of Swiss table of size 5 that has 256 slots and up to 128 inserted keys. Only 12 bits of hash are used by lookup in that case: 5 to select a block, 7 to create a stamp. For all 212 combinations of those bits we could keep the result of first pass lookup in an array. Key id and a match indicating flag can use one byte: 7 bits for key id and 1 bit for the flag. Note that slot id is only needed if we go into 2nd pass lookup, so it can be stored separately and likely only accessed by a small subset of keys. Fast-path lookup becomes almost a single fetch of result from a 4KB array. Lookup arrays used to implement this need to be kept in sync with the main copy of data about slots, which requires extra care during inserts. Since the number of entries in lookup arrays is much higher than the number of slots, this technique only makes sense for small hash tables. + +### Dense comparisons + +If there is at least one key inserted into a hash table, then every slot contains a key id value that corresponds to some actual key that can be used in comparison. That is because empty slots are initialized with 0 as their key id. After the fast-path lookup we get a match-found flag for each input. If it is set, then we need to run a comparison of the input key with the key in the hash table identified by key id returned by fast-path code. The comparison will verify that there is a true match between the keys. We only need to do this for a subset of inputs that have a match candidate, but since we have key id values corresponding to some real key for all inputs, we may as well execute comparisons on all inputs unconditionally. If the majority (e.g. more than 80%) of the keys have a match candidate, the cost of evaluating comparison for the remaining fraction of keys but without filtering may actually be cheaper than the cost of running evaluation only for required keys while referencing filter information. This can be seen as a variant of general preconditioning techniques used to avoid diverging conditional branches in the code. It may be used, based on some heuristic, to verify matches reported by fast-path lookups and is referred to as **dense comparisons**. + +## Resizing + +New hash table is initialized as empty and has only a single block with a space for only a few key entries. Doubling of the hash table size becomes necessary as more keys get inserted. It is invoked during the 2nd pass of the lookups, which also handles inserts. It happens immediately after the number of inserted keys reaches a specific upper limit decided based on a current size of the hash table. There may still be unprocessed entries from the input mini-batch after resizing, so the 2nd pass of the lookup is restarted right after, with the bigger hash table and the remaining subset of unprocessed entries. + +Current policy, that should work reasonably well, is to resize a small hash table (up to 8KB) when it is 50% full. Larger hash tables are resized when 75% full. We want to keep size in memory as small as possible, while maintaining a low probability of blocks becoming full. + +When discussing resizing we will be talking about **resize source** and **resize target** tables. The diagram below shows how the same hash bits are interpreted differently by the source and the target. + +![alt text](img/key_map_10.jpg) + +For a given hash, if a start block id was L in the source table, it will be either (2\*L+0) or (2\*L+1) in the target table. Based on that we can expect data access locality when migrating the data between the tables. + +Resizing is cheap also thanks to the fact that hash values for keys in the hash table are kept together with other slot data and do not need to be recomputed. That means that resizing procedure does not ever need to access the actual bytes of the key. + +### 1st pass + +Based on the hash value for a given slot we can tell whether this slot contains an overflow or non-overflow entry. In the first pass we go over all source slots in sequence, filter out overflow entries and move to the target table all other entries. Non-overflow entries from a block L will be distributed between blocks (2\*L+0) and (2\*L+1) of the target table. None of these target blocks can overflow, since they will be accommodating at most 8 input entries during this pass. + +For every non-overflow entry, the highest bit of a stamp in the source slot decides whether it will go to the left or to the right target block. It is further possible to avoid any conditional branches in this partitioning code, so that the result is friendly to the CPU execution pipeline. + +![alt text](img/key_map_11.jpg) + + +### 2nd pass + +In the second pass of resizing, we scan all source slots again, this time focusing only on the overflow entries that were all skipped in the 1st pass. We simply reinsert them in the target table using generic insertion code with one exception. Since we know that all the source keys are different, there is no need to search for a matching stamp or run key comparisons (or look at the key values). We just need to find the first open block beginning with the start block in the target table and use its first empty slot as the insert destination. + +We expect overflow entries to be rare and therefore the relative cost of that pass should stay low. + diff --git a/cpp/src/arrow/compute/exec/key_compare.cc b/cpp/src/arrow/compute/exec/key_compare.cc new file mode 100644 index 00000000000..f8d74859b01 --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_compare.cc @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/key_compare.h" + +#include +#include + +#include "arrow/compute/exec/util.h" + +namespace arrow { +namespace compute { + +void KeyCompare::CompareRows(uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + KeyEncoder::KeyEncoderContext* ctx, uint32_t* out_num_rows, + uint16_t* out_sel_left_maybe_same, + const KeyEncoder::KeyRowArray& rows_left, + const KeyEncoder::KeyRowArray& rows_right) { + ARROW_DCHECK(rows_left.metadata().is_compatible(rows_right.metadata())); + + if (num_rows_to_compare == 0) { + *out_num_rows = 0; + return; + } + + // Allocate temporary byte and bit vectors + auto bytevector_holder = + util::TempVectorHolder(ctx->stack, num_rows_to_compare); + auto bitvector_holder = + util::TempVectorHolder(ctx->stack, num_rows_to_compare); + + uint8_t* match_bytevector = bytevector_holder.mutable_data(); + uint8_t* match_bitvector = bitvector_holder.mutable_data(); + + // All comparison functions called here will update match byte vector + // (AND it with comparison result) instead of overwriting it. + memset(match_bytevector, 0xff, num_rows_to_compare); + + if (rows_left.metadata().is_fixed_length) { + CompareFixedLength(num_rows_to_compare, sel_left_maybe_null, left_to_right_map, + match_bytevector, ctx, rows_left.metadata().fixed_length, + rows_left.data(1), rows_right.data(1)); + } else { + CompareVaryingLength(num_rows_to_compare, sel_left_maybe_null, left_to_right_map, + match_bytevector, ctx, rows_left.data(2), rows_right.data(2), + rows_left.offsets(), rows_right.offsets()); + } + + // CompareFixedLength can be used to compare nulls as well + bool nulls_present = rows_left.has_any_nulls(ctx) || rows_right.has_any_nulls(ctx); + if (nulls_present) { + CompareFixedLength(num_rows_to_compare, sel_left_maybe_null, left_to_right_map, + match_bytevector, ctx, + rows_left.metadata().null_masks_bytes_per_row, + rows_left.null_masks(), rows_right.null_masks()); + } + + util::BitUtil::bytes_to_bits(ctx->hardware_flags, num_rows_to_compare, match_bytevector, + match_bitvector); + if (sel_left_maybe_null) { + int out_num_rows_int; + util::BitUtil::bits_filter_indexes(0, ctx->hardware_flags, num_rows_to_compare, + match_bitvector, sel_left_maybe_null, + &out_num_rows_int, out_sel_left_maybe_same); + *out_num_rows = out_num_rows_int; + } else { + int out_num_rows_int; + util::BitUtil::bits_to_indexes(0, ctx->hardware_flags, num_rows_to_compare, + match_bitvector, &out_num_rows_int, + out_sel_left_maybe_same); + *out_num_rows = out_num_rows_int; + } +} + +void KeyCompare::CompareFixedLength(uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, + KeyEncoder::KeyEncoderContext* ctx, + uint32_t fixed_length, const uint8_t* rows_left, + const uint8_t* rows_right) { + bool use_selection = (sel_left_maybe_null != nullptr); + + uint32_t num_rows_already_processed = 0; + +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2() && !use_selection) { + // Choose between up-to-8B length, up-to-16B length and any size versions + if (fixed_length <= 8) { + num_rows_already_processed = CompareFixedLength_UpTo8B_avx2( + num_rows_to_compare, left_to_right_map, match_bytevector, fixed_length, + rows_left, rows_right); + } else if (fixed_length <= 16) { + num_rows_already_processed = CompareFixedLength_UpTo16B_avx2( + num_rows_to_compare, left_to_right_map, match_bytevector, fixed_length, + rows_left, rows_right); + } else { + num_rows_already_processed = + CompareFixedLength_avx2(num_rows_to_compare, left_to_right_map, + match_bytevector, fixed_length, rows_left, rows_right); + } + } +#endif + + typedef void (*CompareFixedLengthImp_t)(uint32_t, uint32_t, const uint16_t*, + const uint32_t*, uint8_t*, uint32_t, + const uint8_t*, const uint8_t*); + static const CompareFixedLengthImp_t CompareFixedLengthImp_fn[] = { + CompareFixedLengthImp, CompareFixedLengthImp, + CompareFixedLengthImp, CompareFixedLengthImp, + CompareFixedLengthImp, CompareFixedLengthImp}; + int dispatch_const = (use_selection ? 3 : 0) + + ((fixed_length <= 8) ? 0 : ((fixed_length <= 16) ? 1 : 2)); + CompareFixedLengthImp_fn[dispatch_const]( + num_rows_already_processed, num_rows_to_compare, sel_left_maybe_null, + left_to_right_map, match_bytevector, fixed_length, rows_left, rows_right); +} + +template +void KeyCompare::CompareFixedLengthImp(uint32_t num_rows_already_processed, + uint32_t num_rows, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, uint32_t length, + const uint8_t* rows_left, + const uint8_t* rows_right) { + // Key length (for encoded key) has to be non-zero + ARROW_DCHECK(length > 0); + + // Non-zero length guarantees no underflow + int32_t num_loops_less_one = (static_cast(length) + 7) / 8 - 1; + + // Length remaining in last loop can only be zero for input length equal to zero + uint32_t length_remaining_last_loop = length - num_loops_less_one * 8; + uint64_t tail_mask = (~0ULL) >> (8 * (8 - length_remaining_last_loop)); + + for (uint32_t id_input = num_rows_already_processed; id_input < num_rows; ++id_input) { + uint32_t irow_left = use_selection ? sel_left_maybe_null[id_input] : id_input; + uint32_t irow_right = left_to_right_map[irow_left]; + uint32_t begin_left = length * irow_left; + uint32_t begin_right = length * irow_right; + const uint64_t* key_left_ptr = + reinterpret_cast(rows_left + begin_left); + const uint64_t* key_right_ptr = + reinterpret_cast(rows_right + begin_right); + uint64_t result_or = 0ULL; + int32_t istripe = 0; + + // Specializations for keys up to 8 bytes and between 9 and 16 bytes to + // avoid internal loop over words in the value for short ones. + // + // Template argument 0 means arbitrarily many 64-bit words, + // 1 means up to 1 and 2 means up to 2. + // + if (num_64bit_words == 0) { + for (; istripe < num_loops_less_one; ++istripe) { + uint64_t key_left = key_left_ptr[istripe]; + uint64_t key_right = key_right_ptr[istripe]; + result_or |= (key_left ^ key_right); + } + } else if (num_64bit_words == 2) { + uint64_t key_left = key_left_ptr[istripe]; + uint64_t key_right = key_right_ptr[istripe]; + result_or |= (key_left ^ key_right); + ++istripe; + } + + uint64_t key_left = key_left_ptr[istripe]; + uint64_t key_right = key_right_ptr[istripe]; + result_or |= (tail_mask & (key_left ^ key_right)); + + int result = (result_or == 0 ? 0xff : 0); + match_bytevector[id_input] &= result; + } +} + +void KeyCompare::CompareVaryingLength(uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, + KeyEncoder::KeyEncoderContext* ctx, + const uint8_t* rows_left, const uint8_t* rows_right, + const uint32_t* offsets_left, + const uint32_t* offsets_right) { + bool use_selection = (sel_left_maybe_null != nullptr); + +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2() && !use_selection) { + CompareVaryingLength_avx2(num_rows_to_compare, left_to_right_map, match_bytevector, + rows_left, rows_right, offsets_left, offsets_right); + } else { +#endif + if (use_selection) { + CompareVaryingLengthImp(num_rows_to_compare, sel_left_maybe_null, + left_to_right_map, match_bytevector, rows_left, + rows_right, offsets_left, offsets_right); + } else { + CompareVaryingLengthImp(num_rows_to_compare, sel_left_maybe_null, + left_to_right_map, match_bytevector, rows_left, + rows_right, offsets_left, offsets_right); + } +#if defined(ARROW_HAVE_AVX2) + } +#endif +} + +template +void KeyCompare::CompareVaryingLengthImp( + uint32_t num_rows, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, uint8_t* match_bytevector, + const uint8_t* rows_left, const uint8_t* rows_right, const uint32_t* offsets_left, + const uint32_t* offsets_right) { + static const uint64_t tail_masks[] = { + 0x0000000000000000ULL, 0x00000000000000ffULL, 0x000000000000ffffULL, + 0x0000000000ffffffULL, 0x00000000ffffffffULL, 0x000000ffffffffffULL, + 0x0000ffffffffffffULL, 0x00ffffffffffffffULL, 0xffffffffffffffffULL}; + for (uint32_t i = 0; i < num_rows; ++i) { + uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; + uint32_t irow_right = left_to_right_map[irow_left]; + uint32_t begin_left = offsets_left[irow_left]; + uint32_t begin_right = offsets_right[irow_right]; + uint32_t length_left = offsets_left[irow_left + 1] - begin_left; + uint32_t length_right = offsets_right[irow_right + 1] - begin_right; + uint32_t length = std::min(length_left, length_right); + const uint64_t* key_left_ptr = + reinterpret_cast(rows_left + begin_left); + const uint64_t* key_right_ptr = + reinterpret_cast(rows_right + begin_right); + uint64_t result_or = 0; + int32_t istripe; + // length can be zero + for (istripe = 0; istripe < (static_cast(length) + 7) / 8 - 1; ++istripe) { + uint64_t key_left = key_left_ptr[istripe]; + uint64_t key_right = key_right_ptr[istripe]; + result_or |= (key_left ^ key_right); + } + + uint32_t length_remaining = length - static_cast(istripe) * 8; + uint64_t tail_mask = tail_masks[length_remaining]; + + uint64_t key_left = key_left_ptr[istripe]; + uint64_t key_right = key_right_ptr[istripe]; + result_or |= (tail_mask & (key_left ^ key_right)); + + int result = (result_or == 0 ? 0xff : 0); + match_bytevector[i] &= result; + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_compare.h b/cpp/src/arrow/compute/exec/key_compare.h new file mode 100644 index 00000000000..1dffabb884b --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_compare.h @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/compute/exec/key_encode.h" +#include "arrow/compute/exec/util.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { +namespace compute { + +class KeyCompare { + public: + // Returns a single 16-bit selection vector of rows that failed comparison. + // If there is input selection on the left, the resulting selection is a filtered image + // of input selection. + static void CompareRows(uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + KeyEncoder::KeyEncoderContext* ctx, uint32_t* out_num_rows, + uint16_t* out_sel_left_maybe_same, + const KeyEncoder::KeyRowArray& rows_left, + const KeyEncoder::KeyRowArray& rows_right); + + private: + static void CompareFixedLength(uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, + KeyEncoder::KeyEncoderContext* ctx, + uint32_t fixed_length, const uint8_t* rows_left, + const uint8_t* rows_right); + static void CompareVaryingLength(uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, + KeyEncoder::KeyEncoderContext* ctx, + const uint8_t* rows_left, const uint8_t* rows_right, + const uint32_t* offsets_left, + const uint32_t* offsets_right); + + // Second template argument is 0, 1 or 2. + // 0 means arbitrarily many 64-bit words, 1 means up to 1 and 2 means up to 2. + template + static void CompareFixedLengthImp(uint32_t num_rows_already_processed, + uint32_t num_rows, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, uint32_t length, + const uint8_t* rows_left, const uint8_t* rows_right); + template + static void CompareVaryingLengthImp(uint32_t num_rows, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, const uint8_t* rows_left, + const uint8_t* rows_right, + const uint32_t* offsets_left, + const uint32_t* offsets_right); + +#if defined(ARROW_HAVE_AVX2) + + static uint32_t CompareFixedLength_UpTo8B_avx2( + uint32_t num_rows, const uint32_t* left_to_right_map, uint8_t* match_bytevector, + uint32_t length, const uint8_t* rows_left, const uint8_t* rows_right); + static uint32_t CompareFixedLength_UpTo16B_avx2( + uint32_t num_rows, const uint32_t* left_to_right_map, uint8_t* match_bytevector, + uint32_t length, const uint8_t* rows_left, const uint8_t* rows_right); + static uint32_t CompareFixedLength_avx2(uint32_t num_rows, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, uint32_t length, + const uint8_t* rows_left, + const uint8_t* rows_right); + static void CompareVaryingLength_avx2( + uint32_t num_rows, const uint32_t* left_to_right_map, uint8_t* match_bytevector, + const uint8_t* rows_left, const uint8_t* rows_right, const uint32_t* offsets_left, + const uint32_t* offsets_right); + +#endif +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_compare_avx2.cc b/cpp/src/arrow/compute/exec/key_compare_avx2.cc new file mode 100644 index 00000000000..6abdf6c3c3a --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_compare_avx2.cc @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec/key_compare.h" +#include "arrow/util/bit_util.h" + +namespace arrow { +namespace compute { + +#if defined(ARROW_HAVE_AVX2) + +uint32_t KeyCompare::CompareFixedLength_UpTo8B_avx2( + uint32_t num_rows, const uint32_t* left_to_right_map, uint8_t* match_bytevector, + uint32_t length, const uint8_t* rows_left, const uint8_t* rows_right) { + ARROW_DCHECK(length <= 8); + __m256i offset_left = _mm256_setr_epi64x(0, length, length * 2, length * 3); + __m256i offset_left_incr = _mm256_set1_epi64x(length * 4); + __m256i mask = _mm256_set1_epi64x(~0ULL >> (8 * (8 - length))); + + constexpr uint32_t unroll = 4; + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + auto key_left = _mm256_i64gather_epi64( + reinterpret_cast(rows_left), offset_left, 1); + offset_left = _mm256_add_epi64(offset_left, offset_left_incr); + __m128i offset_right = + _mm_loadu_si128(reinterpret_cast(left_to_right_map) + i); + offset_right = _mm_mullo_epi32(offset_right, _mm_set1_epi32(length)); + + auto key_right = _mm256_i32gather_epi64( + reinterpret_cast(rows_right), offset_right, 1); + uint32_t cmp = _mm256_movemask_epi8(_mm256_cmpeq_epi64( + _mm256_and_si256(key_left, mask), _mm256_and_si256(key_right, mask))); + reinterpret_cast(match_bytevector)[i] &= cmp; + } + + uint32_t num_rows_processed = num_rows - (num_rows % unroll); + return num_rows_processed; +} + +uint32_t KeyCompare::CompareFixedLength_UpTo16B_avx2( + uint32_t num_rows, const uint32_t* left_to_right_map, uint8_t* match_bytevector, + uint32_t length, const uint8_t* rows_left, const uint8_t* rows_right) { + ARROW_DCHECK(length <= 16); + + constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; + constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL; + + __m256i mask = + _mm256_cmpgt_epi8(_mm256_set1_epi8(length), + _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15, + kByteSequence0To7, kByteSequence8To15)); + const uint8_t* key_left_ptr = rows_left; + + constexpr uint32_t unroll = 2; + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + auto key_left = _mm256_inserti128_si256( + _mm256_castsi128_si256( + _mm_loadu_si128(reinterpret_cast(key_left_ptr))), + _mm_loadu_si128(reinterpret_cast(key_left_ptr + length)), 1); + key_left_ptr += length * 2; + auto key_right = _mm256_inserti128_si256( + _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast( + rows_right + length * left_to_right_map[2 * i]))), + _mm_loadu_si128(reinterpret_cast( + rows_right + length * left_to_right_map[2 * i + 1])), + 1); + __m256i cmp = _mm256_cmpeq_epi64(_mm256_and_si256(key_left, mask), + _mm256_and_si256(key_right, mask)); + cmp = _mm256_and_si256(cmp, _mm256_shuffle_epi32(cmp, 0xee)); // 0b11101110 + cmp = _mm256_permute4x64_epi64(cmp, 0x08); // 0b00001000 + reinterpret_cast(match_bytevector)[i] &= + (_mm256_movemask_epi8(cmp) & 0xffff); + } + + uint32_t num_rows_processed = num_rows - (num_rows % unroll); + return num_rows_processed; +} + +uint32_t KeyCompare::CompareFixedLength_avx2(uint32_t num_rows, + const uint32_t* left_to_right_map, + uint8_t* match_bytevector, uint32_t length, + const uint8_t* rows_left, + const uint8_t* rows_right) { + ARROW_DCHECK(length > 0); + + constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; + constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL; + constexpr uint64_t kByteSequence16To23 = 0x1716151413121110ULL; + constexpr uint64_t kByteSequence24To31 = 0x1f1e1d1c1b1a1918ULL; + + // Non-zero length guarantees no underflow + int32_t num_loops_less_one = (static_cast(length) + 31) / 32 - 1; + + __m256i tail_mask = + _mm256_cmpgt_epi8(_mm256_set1_epi8(length - num_loops_less_one * 32), + _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15, + kByteSequence16To23, kByteSequence24To31)); + + for (uint32_t irow_left = 0; irow_left < num_rows; ++irow_left) { + uint32_t irow_right = left_to_right_map[irow_left]; + uint32_t begin_left = length * irow_left; + uint32_t begin_right = length * irow_right; + const __m256i* key_left_ptr = + reinterpret_cast(rows_left + begin_left); + const __m256i* key_right_ptr = + reinterpret_cast(rows_right + begin_right); + __m256i result_or = _mm256_setzero_si256(); + int32_t i; + // length cannot be zero + for (i = 0; i < num_loops_less_one; ++i) { + __m256i key_left = _mm256_loadu_si256(key_left_ptr + i); + __m256i key_right = _mm256_loadu_si256(key_right_ptr + i); + result_or = _mm256_or_si256(result_or, _mm256_xor_si256(key_left, key_right)); + } + + __m256i key_left = _mm256_loadu_si256(key_left_ptr + i); + __m256i key_right = _mm256_loadu_si256(key_right_ptr + i); + result_or = _mm256_or_si256( + result_or, _mm256_and_si256(tail_mask, _mm256_xor_si256(key_left, key_right))); + int result = _mm256_testz_si256(result_or, result_or) * 0xff; + match_bytevector[irow_left] &= result; + } + + uint32_t num_rows_processed = num_rows; + return num_rows_processed; +} + +void KeyCompare::CompareVaryingLength_avx2( + uint32_t num_rows, const uint32_t* left_to_right_map, uint8_t* match_bytevector, + const uint8_t* rows_left, const uint8_t* rows_right, const uint32_t* offsets_left, + const uint32_t* offsets_right) { + for (uint32_t irow_left = 0; irow_left < num_rows; ++irow_left) { + uint32_t irow_right = left_to_right_map[irow_left]; + uint32_t begin_left = offsets_left[irow_left]; + uint32_t begin_right = offsets_right[irow_right]; + uint32_t length_left = offsets_left[irow_left + 1] - begin_left; + uint32_t length_right = offsets_right[irow_right + 1] - begin_right; + uint32_t length = std::min(length_left, length_right); + auto key_left_ptr = reinterpret_cast(rows_left + begin_left); + auto key_right_ptr = reinterpret_cast(rows_right + begin_right); + __m256i result_or = _mm256_setzero_si256(); + int32_t i; + // length can be zero + for (i = 0; i < (static_cast(length) + 31) / 32 - 1; ++i) { + __m256i key_left = _mm256_loadu_si256(key_left_ptr + i); + __m256i key_right = _mm256_loadu_si256(key_right_ptr + i); + result_or = _mm256_or_si256(result_or, _mm256_xor_si256(key_left, key_right)); + } + + constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; + constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL; + constexpr uint64_t kByteSequence16To23 = 0x1716151413121110ULL; + constexpr uint64_t kByteSequence24To31 = 0x1f1e1d1c1b1a1918ULL; + + __m256i tail_mask = + _mm256_cmpgt_epi8(_mm256_set1_epi8(length - i * 32), + _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15, + kByteSequence16To23, kByteSequence24To31)); + + __m256i key_left = _mm256_loadu_si256(key_left_ptr + i); + __m256i key_right = _mm256_loadu_si256(key_right_ptr + i); + result_or = _mm256_or_si256( + result_or, _mm256_and_si256(tail_mask, _mm256_xor_si256(key_left, key_right))); + int result = _mm256_testz_si256(result_or, result_or) * 0xff; + match_bytevector[irow_left] &= result; + } +} + +#endif + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_encode.cc b/cpp/src/arrow/compute/exec/key_encode.cc new file mode 100644 index 00000000000..0c5f27c51c1 --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_encode.cc @@ -0,0 +1,1625 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/key_encode.h" + +#include + +#include + +#include "arrow/compute/exec/util.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/ubsan.h" + +namespace arrow { +namespace compute { + +KeyEncoder::KeyRowArray::KeyRowArray() + : pool_(nullptr), rows_capacity_(0), bytes_capacity_(0) {} + +Status KeyEncoder::KeyRowArray::Init(MemoryPool* pool, const KeyRowMetadata& metadata) { + pool_ = pool; + metadata_ = metadata; + + ARROW_DCHECK(!null_masks_ && !offsets_ && !rows_); + + constexpr int64_t rows_capacity = 8; + constexpr int64_t bytes_capacity = 1024; + + // Null masks + ARROW_ASSIGN_OR_RAISE(auto null_masks, + AllocateResizableBuffer(size_null_masks(rows_capacity), pool_)); + null_masks_ = std::move(null_masks); + memset(null_masks_->mutable_data(), 0, size_null_masks(rows_capacity)); + + // Offsets and rows + if (!metadata.is_fixed_length) { + ARROW_ASSIGN_OR_RAISE(auto offsets, + AllocateResizableBuffer(size_offsets(rows_capacity), pool_)); + offsets_ = std::move(offsets); + memset(offsets_->mutable_data(), 0, size_offsets(rows_capacity)); + reinterpret_cast(offsets_->mutable_data())[0] = 0; + + ARROW_ASSIGN_OR_RAISE( + auto rows, + AllocateResizableBuffer(size_rows_varying_length(bytes_capacity), pool_)); + rows_ = std::move(rows); + memset(rows_->mutable_data(), 0, size_rows_varying_length(bytes_capacity)); + bytes_capacity_ = size_rows_varying_length(bytes_capacity) - padding_for_vectors; + } else { + ARROW_ASSIGN_OR_RAISE( + auto rows, AllocateResizableBuffer(size_rows_fixed_length(rows_capacity), pool_)); + rows_ = std::move(rows); + memset(rows_->mutable_data(), 0, size_rows_fixed_length(rows_capacity)); + bytes_capacity_ = size_rows_fixed_length(rows_capacity) - padding_for_vectors; + } + + update_buffer_pointers(); + + rows_capacity_ = rows_capacity; + + num_rows_ = 0; + num_rows_for_has_any_nulls_ = 0; + has_any_nulls_ = false; + + return Status::OK(); +} + +void KeyEncoder::KeyRowArray::Clean() { + num_rows_ = 0; + num_rows_for_has_any_nulls_ = 0; + has_any_nulls_ = false; + + if (!metadata_.is_fixed_length) { + reinterpret_cast(offsets_->mutable_data())[0] = 0; + } +} + +int64_t KeyEncoder::KeyRowArray::size_null_masks(int64_t num_rows) { + return num_rows * metadata_.null_masks_bytes_per_row + padding_for_vectors; +} + +int64_t KeyEncoder::KeyRowArray::size_offsets(int64_t num_rows) { + return (num_rows + 1) * sizeof(uint32_t) + padding_for_vectors; +} + +int64_t KeyEncoder::KeyRowArray::size_rows_fixed_length(int64_t num_rows) { + return num_rows * metadata_.fixed_length + padding_for_vectors; +} + +int64_t KeyEncoder::KeyRowArray::size_rows_varying_length(int64_t num_bytes) { + return num_bytes + padding_for_vectors; +} + +void KeyEncoder::KeyRowArray::update_buffer_pointers() { + buffers_[0] = mutable_buffers_[0] = null_masks_->mutable_data(); + if (metadata_.is_fixed_length) { + buffers_[1] = mutable_buffers_[1] = rows_->mutable_data(); + buffers_[2] = mutable_buffers_[2] = nullptr; + } else { + buffers_[1] = mutable_buffers_[1] = offsets_->mutable_data(); + buffers_[2] = mutable_buffers_[2] = rows_->mutable_data(); + } +} + +Status KeyEncoder::KeyRowArray::ResizeFixedLengthBuffers(int64_t num_extra_rows) { + if (rows_capacity_ >= num_rows_ + num_extra_rows) { + return Status::OK(); + } + + int64_t rows_capacity_new = std::max(static_cast(1), 2 * rows_capacity_); + while (rows_capacity_new < num_rows_ + num_extra_rows) { + rows_capacity_new *= 2; + } + + // Null masks + RETURN_NOT_OK(null_masks_->Resize(size_null_masks(rows_capacity_new), false)); + memset(null_masks_->mutable_data() + size_null_masks(rows_capacity_), 0, + size_null_masks(rows_capacity_new) - size_null_masks(rows_capacity_)); + + // Either offsets or rows + if (!metadata_.is_fixed_length) { + RETURN_NOT_OK(offsets_->Resize(size_offsets(rows_capacity_new), false)); + memset(offsets_->mutable_data() + size_offsets(rows_capacity_), 0, + size_offsets(rows_capacity_new) - size_offsets(rows_capacity_)); + } else { + RETURN_NOT_OK(rows_->Resize(size_rows_fixed_length(rows_capacity_new), false)); + memset(rows_->mutable_data() + size_rows_fixed_length(rows_capacity_), 0, + size_rows_fixed_length(rows_capacity_new) - + size_rows_fixed_length(rows_capacity_)); + bytes_capacity_ = size_rows_fixed_length(rows_capacity_new) - padding_for_vectors; + } + + update_buffer_pointers(); + + rows_capacity_ = rows_capacity_new; + + return Status::OK(); +} + +Status KeyEncoder::KeyRowArray::ResizeOptionalVaryingLengthBuffer( + int64_t num_extra_bytes) { + int64_t num_bytes = offsets()[num_rows_]; + if (bytes_capacity_ >= num_bytes + num_extra_bytes || metadata_.is_fixed_length) { + return Status::OK(); + } + + int64_t bytes_capacity_new = std::max(static_cast(1), 2 * bytes_capacity_); + while (bytes_capacity_new < num_bytes + num_extra_bytes) { + bytes_capacity_new *= 2; + } + + RETURN_NOT_OK(rows_->Resize(size_rows_varying_length(bytes_capacity_new), false)); + memset(rows_->mutable_data() + size_rows_varying_length(bytes_capacity_), 0, + size_rows_varying_length(bytes_capacity_new) - + size_rows_varying_length(bytes_capacity_)); + + update_buffer_pointers(); + + bytes_capacity_ = bytes_capacity_new; + + return Status::OK(); +} + +Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, + uint32_t num_rows_to_append, + const uint16_t* source_row_ids) { + ARROW_DCHECK(metadata_.is_compatible(from.metadata())); + + RETURN_NOT_OK(ResizeFixedLengthBuffers(num_rows_to_append)); + + if (!metadata_.is_fixed_length) { + // Varying-length rows + const uint32_t* from_offsets = + reinterpret_cast(from.offsets_->data()); + uint32_t* to_offsets = reinterpret_cast(offsets_->mutable_data()); + uint32_t total_length = to_offsets[num_rows_]; + uint32_t total_length_to_append = 0; + for (uint32_t i = 0; i < num_rows_to_append; ++i) { + uint16_t row_id = source_row_ids[i]; + uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id]; + total_length_to_append += length; + to_offsets[num_rows_ + i + 1] = total_length + total_length_to_append; + } + + RETURN_NOT_OK(ResizeOptionalVaryingLengthBuffer(total_length_to_append)); + + const uint8_t* src = from.rows_->data(); + uint8_t* dst = rows_->mutable_data() + total_length; + for (uint32_t i = 0; i < num_rows_to_append; ++i) { + uint16_t row_id = source_row_ids[i]; + uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id]; + const uint64_t* src64 = + reinterpret_cast(src + from_offsets[row_id]); + uint64_t* dst64 = reinterpret_cast(dst); + for (uint32_t j = 0; j < (length + 7) / 8; ++j) { + dst64[j] = src64[j]; + } + dst += length; + } + } else { + // Fixed-length rows + const uint8_t* src = from.rows_->data(); + uint8_t* dst = rows_->mutable_data() + num_rows_ * metadata_.fixed_length; + for (uint32_t i = 0; i < num_rows_to_append; ++i) { + uint16_t row_id = source_row_ids[i]; + uint32_t length = metadata_.fixed_length; + const uint64_t* src64 = reinterpret_cast(src + length * row_id); + uint64_t* dst64 = reinterpret_cast(dst); + for (uint32_t j = 0; j < (length + 7) / 8; ++j) { + dst64[j] = src64[j]; + } + dst += length; + } + } + + // Null masks + uint32_t byte_length = metadata_.null_masks_bytes_per_row; + uint64_t dst_byte_offset = num_rows_ * byte_length; + const uint8_t* src_base = from.null_masks_->data(); + uint8_t* dst_base = null_masks_->mutable_data(); + for (uint32_t i = 0; i < num_rows_to_append; ++i) { + uint32_t row_id = source_row_ids[i]; + int64_t src_byte_offset = row_id * byte_length; + const uint8_t* src = src_base + src_byte_offset; + uint8_t* dst = dst_base + dst_byte_offset; + for (uint32_t ibyte = 0; ibyte < byte_length; ++ibyte) { + dst[ibyte] = src[ibyte]; + } + dst_byte_offset += byte_length; + } + + num_rows_ += num_rows_to_append; + + return Status::OK(); +} + +Status KeyEncoder::KeyRowArray::AppendEmpty(uint32_t num_rows_to_append, + uint32_t num_extra_bytes_to_append) { + RETURN_NOT_OK(ResizeFixedLengthBuffers(num_rows_to_append)); + RETURN_NOT_OK(ResizeOptionalVaryingLengthBuffer(num_extra_bytes_to_append)); + num_rows_ += num_rows_to_append; + if (metadata_.row_alignment > 1 || metadata_.string_alignment > 1) { + memset(rows_->mutable_data(), 0, bytes_capacity_); + } + return Status::OK(); +} + +bool KeyEncoder::KeyRowArray::has_any_nulls(const KeyEncoderContext* ctx) const { + if (has_any_nulls_) { + return true; + } + if (num_rows_for_has_any_nulls_ < num_rows_) { + auto size_per_row = metadata().null_masks_bytes_per_row; + has_any_nulls_ = !util::BitUtil::are_all_bytes_zero( + ctx->hardware_flags, null_masks() + size_per_row * num_rows_for_has_any_nulls_, + static_cast(size_per_row * (num_rows_ - num_rows_for_has_any_nulls_))); + num_rows_for_has_any_nulls_ = num_rows_; + } + return has_any_nulls_; +} + +KeyEncoder::KeyColumnArray::KeyColumnArray(const KeyColumnMetadata& metadata, + const KeyColumnArray& left, + const KeyColumnArray& right, + int buffer_id_to_replace) { + metadata_ = metadata; + length_ = left.length(); + for (int i = 0; i < max_buffers_; ++i) { + buffers_[i] = left.buffers_[i]; + mutable_buffers_[i] = left.mutable_buffers_[i]; + } + buffers_[buffer_id_to_replace] = right.buffers_[buffer_id_to_replace]; + mutable_buffers_[buffer_id_to_replace] = right.mutable_buffers_[buffer_id_to_replace]; +} + +KeyEncoder::KeyColumnArray::KeyColumnArray(const KeyColumnMetadata& metadata, + int64_t length, const uint8_t* buffer0, + const uint8_t* buffer1, + const uint8_t* buffer2) { + metadata_ = metadata; + length_ = length; + buffers_[0] = buffer0; + buffers_[1] = buffer1; + buffers_[2] = buffer2; + mutable_buffers_[0] = mutable_buffers_[1] = mutable_buffers_[2] = nullptr; +} + +KeyEncoder::KeyColumnArray::KeyColumnArray(const KeyColumnMetadata& metadata, + int64_t length, uint8_t* buffer0, + uint8_t* buffer1, uint8_t* buffer2) { + metadata_ = metadata; + length_ = length; + buffers_[0] = mutable_buffers_[0] = buffer0; + buffers_[1] = mutable_buffers_[1] = buffer1; + buffers_[2] = mutable_buffers_[2] = buffer2; +} + +KeyEncoder::KeyColumnArray::KeyColumnArray(const KeyColumnArray& from, int64_t start, + int64_t length) { + ARROW_DCHECK((start % 8) == 0); + metadata_ = from.metadata_; + length_ = length; + uint32_t fixed_size = + !metadata_.is_fixed_length ? sizeof(uint32_t) : metadata_.fixed_length; + + buffers_[0] = from.buffers_[0] ? from.buffers_[0] + start / 8 : nullptr; + mutable_buffers_[0] = + from.mutable_buffers_[0] ? from.mutable_buffers_[0] + start / 8 : nullptr; + + if (fixed_size == 0) { + buffers_[1] = from.buffers_[1] ? from.buffers_[1] + start / 8 : nullptr; + mutable_buffers_[1] = + from.mutable_buffers_[1] ? from.mutable_buffers_[1] + start / 8 : nullptr; + } else { + buffers_[1] = from.buffers_[1] ? from.buffers_[1] + start * fixed_size : nullptr; + mutable_buffers_[1] = from.mutable_buffers_[1] + ? from.mutable_buffers_[1] + start * fixed_size + : nullptr; + } + + buffers_[2] = from.buffers_[2]; + mutable_buffers_[2] = from.mutable_buffers_[2]; +} + +KeyEncoder::KeyColumnArray KeyEncoder::TransformBoolean::ArrayReplace( + const KeyColumnArray& column, const KeyColumnArray& temp) { + // Make sure that the temp buffer is large enough + ARROW_DCHECK(temp.length() >= column.length() && temp.metadata().is_fixed_length && + temp.metadata().fixed_length >= sizeof(uint8_t)); + KeyColumnMetadata metadata; + metadata.is_fixed_length = true; + metadata.fixed_length = sizeof(uint8_t); + constexpr int buffer_index = 1; + KeyColumnArray result = KeyColumnArray(metadata, column, temp, buffer_index); + return result; +} + +void KeyEncoder::TransformBoolean::PreEncode(const KeyColumnArray& input, + KeyColumnArray* output, + KeyEncoderContext* ctx) { + // Make sure that metadata and lengths are compatible. + ARROW_DCHECK(output->metadata().is_fixed_length == input.metadata().is_fixed_length); + ARROW_DCHECK(output->metadata().fixed_length == 1 && + input.metadata().fixed_length == 0); + ARROW_DCHECK(output->length() == input.length()); + constexpr int buffer_index = 1; + ARROW_DCHECK(input.data(buffer_index) != nullptr); + ARROW_DCHECK(output->mutable_data(buffer_index) != nullptr); + util::BitUtil::bits_to_bytes(ctx->hardware_flags, static_cast(input.length()), + input.data(buffer_index), + output->mutable_data(buffer_index)); +} + +void KeyEncoder::TransformBoolean::PostDecode(const KeyColumnArray& input, + KeyColumnArray* output, + KeyEncoderContext* ctx) { + // Make sure that metadata and lengths are compatible. + ARROW_DCHECK(output->metadata().is_fixed_length == input.metadata().is_fixed_length); + ARROW_DCHECK(output->metadata().fixed_length == 0 && + input.metadata().fixed_length == 1); + ARROW_DCHECK(output->length() == input.length()); + constexpr int buffer_index = 1; + ARROW_DCHECK(input.data(buffer_index) != nullptr); + ARROW_DCHECK(output->mutable_data(buffer_index) != nullptr); + + util::BitUtil::bytes_to_bits(ctx->hardware_flags, static_cast(input.length()), + input.data(buffer_index), + output->mutable_data(buffer_index)); +} + +bool KeyEncoder::EncoderInteger::IsBoolean(const KeyColumnMetadata& metadata) { + return metadata.is_fixed_length && metadata.fixed_length == 0; +} + +bool KeyEncoder::EncoderInteger::UsesTransform(const KeyColumnArray& column) { + return IsBoolean(column.metadata()); +} + +KeyEncoder::KeyColumnArray KeyEncoder::EncoderInteger::ArrayReplace( + const KeyColumnArray& column, const KeyColumnArray& temp) { + if (IsBoolean(column.metadata())) { + return TransformBoolean::ArrayReplace(column, temp); + } + return column; +} + +void KeyEncoder::EncoderInteger::PreEncode(const KeyColumnArray& input, + KeyColumnArray* output, + KeyEncoderContext* ctx) { + if (IsBoolean(input.metadata())) { + TransformBoolean::PreEncode(input, output, ctx); + } +} + +void KeyEncoder::EncoderInteger::PostDecode(const KeyColumnArray& input, + KeyColumnArray* output, + KeyEncoderContext* ctx) { + if (IsBoolean(output->metadata())) { + TransformBoolean::PostDecode(input, output, ctx); + } +} + +void KeyEncoder::EncoderInteger::Encode(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col, KeyEncoderContext* ctx, + KeyColumnArray* temp) { + KeyColumnArray col_prep; + if (UsesTransform(col)) { + col_prep = ArrayReplace(col, *temp); + PreEncode(col, &col_prep, ctx); + } else { + col_prep = col; + } + + uint32_t num_rows = static_cast(col.length()); + + // When we have a single fixed length column we can just do memcpy + if (rows->metadata().is_fixed_length && + rows->metadata().fixed_length == col.metadata().fixed_length) { + ARROW_DCHECK(offset_within_row == 0); + uint32_t row_size = col.metadata().fixed_length; + memcpy(rows->mutable_data(1), col.data(1), num_rows * row_size); + } else if (rows->metadata().is_fixed_length) { + uint32_t row_size = rows->metadata().fixed_length; + uint8_t* row_base = rows->mutable_data(1) + offset_within_row; + const uint8_t* col_base = col_prep.data(1); + switch (col_prep.metadata().fixed_length) { + case 1: + for (uint32_t i = 0; i < num_rows; ++i) { + row_base[i * row_size] = col_base[i]; + } + break; + case 2: + for (uint32_t i = 0; i < num_rows; ++i) { + *reinterpret_cast(row_base + i * row_size) = + reinterpret_cast(col_base)[i]; + } + break; + case 4: + for (uint32_t i = 0; i < num_rows; ++i) { + *reinterpret_cast(row_base + i * row_size) = + reinterpret_cast(col_base)[i]; + } + break; + case 8: + for (uint32_t i = 0; i < num_rows; ++i) { + *reinterpret_cast(row_base + i * row_size) = + reinterpret_cast(col_base)[i]; + } + break; + default: + ARROW_DCHECK(false); + } + } else { + const uint32_t* row_offsets = rows->offsets(); + uint8_t* row_base = rows->mutable_data(2) + offset_within_row; + const uint8_t* col_base = col_prep.data(1); + switch (col_prep.metadata().fixed_length) { + case 1: + for (uint32_t i = 0; i < num_rows; ++i) { + row_base[row_offsets[i]] = col_base[i]; + } + break; + case 2: + for (uint32_t i = 0; i < num_rows; ++i) { + *reinterpret_cast(row_base + row_offsets[i]) = + reinterpret_cast(col_base)[i]; + } + break; + case 4: + for (uint32_t i = 0; i < num_rows; ++i) { + *reinterpret_cast(row_base + row_offsets[i]) = + reinterpret_cast(col_base)[i]; + } + break; + case 8: + for (uint32_t i = 0; i < num_rows; ++i) { + *reinterpret_cast(row_base + row_offsets[i]) = + reinterpret_cast(col_base)[i]; + } + break; + default: + ARROW_DCHECK(false); + } + } +} + +void KeyEncoder::EncoderInteger::Decode(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col, + KeyEncoderContext* ctx, KeyColumnArray* temp) { + KeyColumnArray col_prep; + if (UsesTransform(*col)) { + col_prep = ArrayReplace(*col, *temp); + } else { + col_prep = *col; + } + + // When we have a single fixed length column we can just do memcpy + if (rows.metadata().is_fixed_length && + col_prep.metadata().fixed_length == rows.metadata().fixed_length) { + ARROW_DCHECK(offset_within_row == 0); + uint32_t row_size = rows.metadata().fixed_length; + memcpy(col_prep.mutable_data(1), rows.data(1) + start_row * row_size, + num_rows * row_size); + } else if (rows.metadata().is_fixed_length) { + uint32_t row_size = rows.metadata().fixed_length; + const uint8_t* row_base = rows.data(1) + start_row * row_size; + row_base += offset_within_row; + uint8_t* col_base = col_prep.mutable_data(1); + switch (col_prep.metadata().fixed_length) { + case 1: + for (uint32_t i = 0; i < num_rows; ++i) { + col_base[i] = row_base[i * row_size]; + } + break; + case 2: + for (uint32_t i = 0; i < num_rows; ++i) { + reinterpret_cast(col_base)[i] = + *reinterpret_cast(row_base + i * row_size); + } + break; + case 4: + for (uint32_t i = 0; i < num_rows; ++i) { + reinterpret_cast(col_base)[i] = + *reinterpret_cast(row_base + i * row_size); + } + break; + case 8: + for (uint32_t i = 0; i < num_rows; ++i) { + reinterpret_cast(col_base)[i] = + *reinterpret_cast(row_base + i * row_size); + } + break; + default: + ARROW_DCHECK(false); + } + } else { + const uint32_t* row_offsets = rows.offsets() + start_row; + const uint8_t* row_base = rows.data(2); + row_base += offset_within_row; + uint8_t* col_base = col_prep.mutable_data(1); + switch (col_prep.metadata().fixed_length) { + case 1: + for (uint32_t i = 0; i < num_rows; ++i) { + col_base[i] = row_base[row_offsets[i]]; + } + break; + case 2: + for (uint32_t i = 0; i < num_rows; ++i) { + reinterpret_cast(col_base)[i] = + *reinterpret_cast(row_base + row_offsets[i]); + } + break; + case 4: + for (uint32_t i = 0; i < num_rows; ++i) { + reinterpret_cast(col_base)[i] = + *reinterpret_cast(row_base + row_offsets[i]); + } + break; + case 8: + for (uint32_t i = 0; i < num_rows; ++i) { + reinterpret_cast(col_base)[i] = + *reinterpret_cast(row_base + row_offsets[i]); + } + break; + default: + ARROW_DCHECK(false); + } + } + + if (UsesTransform(*col)) { + PostDecode(col_prep, col, ctx); + } +} + +bool KeyEncoder::EncoderBinary::IsInteger(const KeyColumnMetadata& metadata) { + bool is_fixed_length = metadata.is_fixed_length; + auto size = metadata.fixed_length; + return is_fixed_length && + (size == 0 || size == 1 || size == 2 || size == 4 || size == 8); +} + +void KeyEncoder::EncoderBinary::Encode(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col, KeyEncoderContext* ctx, + KeyColumnArray* temp) { + if (IsInteger(col.metadata())) { + EncoderInteger::Encode(offset_within_row, rows, col, ctx, temp); + } else { + KeyColumnArray col_prep; + if (EncoderInteger::UsesTransform(col)) { + col_prep = EncoderInteger::ArrayReplace(col, *temp); + EncoderInteger::PreEncode(col, &col_prep, ctx); + } else { + col_prep = col; + } + + bool is_row_fixed_length = rows->metadata().is_fixed_length; + +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2()) { + EncodeHelper_avx2(is_row_fixed_length, offset_within_row, rows, col); + } else { +#endif + if (is_row_fixed_length) { + EncodeImp(offset_within_row, rows, col); + } else { + EncodeImp(offset_within_row, rows, col); + } +#if defined(ARROW_HAVE_AVX2) + } +#endif + } + + ARROW_DCHECK(temp->metadata().is_fixed_length); + ARROW_DCHECK(temp->length() * temp->metadata().fixed_length >= + col.length() * static_cast(sizeof(uint16_t))); + + KeyColumnArray temp16bit(KeyColumnMetadata(true, sizeof(uint16_t)), col.length(), + nullptr, temp->mutable_data(1), nullptr); + ColumnMemsetNulls(offset_within_row, rows, col, ctx, &temp16bit, 0xae); +} + +void KeyEncoder::EncoderBinary::Decode(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col, + KeyEncoderContext* ctx, KeyColumnArray* temp) { + if (IsInteger(col->metadata())) { + EncoderInteger::Decode(start_row, num_rows, offset_within_row, rows, col, ctx, temp); + } else { + KeyColumnArray col_prep; + if (EncoderInteger::UsesTransform(*col)) { + col_prep = EncoderInteger::ArrayReplace(*col, *temp); + } else { + col_prep = *col; + } + + bool is_row_fixed_length = rows.metadata().is_fixed_length; + +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2()) { + DecodeHelper_avx2(is_row_fixed_length, start_row, num_rows, offset_within_row, rows, + col); + } else { +#endif + if (is_row_fixed_length) { + DecodeImp(start_row, num_rows, offset_within_row, rows, col); + } else { + DecodeImp(start_row, num_rows, offset_within_row, rows, col); + } +#if defined(ARROW_HAVE_AVX2) + } +#endif + + if (EncoderInteger::UsesTransform(*col)) { + EncoderInteger::PostDecode(col_prep, col, ctx); + } + } +} + +template +void KeyEncoder::EncoderBinary::EncodeImp(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col) { + EncodeDecodeHelper( + 0, static_cast(col.length()), offset_within_row, rows, rows, &col, + nullptr, [](uint8_t* dst, const uint8_t* src, int64_t length) { + uint64_t* dst64 = reinterpret_cast(dst); + const uint64_t* src64 = reinterpret_cast(src); + uint32_t istripe; + for (istripe = 0; istripe < length / 8; ++istripe) { + dst64[istripe] = util::SafeLoad(src64 + istripe); + } + if ((length % 8) > 0) { + uint64_t mask_last = ~0ULL >> (8 * (8 * (istripe + 1) - length)); + dst64[istripe] = (dst64[istripe] & ~mask_last) | + (util::SafeLoad(src64 + istripe) & mask_last); + } + }); +} + +template +void KeyEncoder::EncoderBinary::DecodeImp(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col) { + EncodeDecodeHelper( + start_row, num_rows, offset_within_row, &rows, nullptr, col, col, + [](uint8_t* dst, const uint8_t* src, int64_t length) { + for (uint32_t istripe = 0; istripe < (length + 7) / 8; ++istripe) { + uint64_t* dst64 = reinterpret_cast(dst); + const uint64_t* src64 = reinterpret_cast(src); + util::SafeStore(dst64 + istripe, src64[istripe]); + } + }); +} + +void KeyEncoder::EncoderBinary::ColumnMemsetNulls( + uint32_t offset_within_row, KeyRowArray* rows, const KeyColumnArray& col, + KeyEncoderContext* ctx, KeyColumnArray* temp_vector_16bit, uint8_t byte_value) { + typedef void (*ColumnMemsetNullsImp_t)(uint32_t, KeyRowArray*, const KeyColumnArray&, + KeyEncoderContext*, KeyColumnArray*, uint8_t); + static const ColumnMemsetNullsImp_t ColumnMemsetNullsImp_fn[] = { + ColumnMemsetNullsImp, ColumnMemsetNullsImp, + ColumnMemsetNullsImp, ColumnMemsetNullsImp, + ColumnMemsetNullsImp, ColumnMemsetNullsImp, + ColumnMemsetNullsImp, ColumnMemsetNullsImp, + ColumnMemsetNullsImp, ColumnMemsetNullsImp}; + uint32_t col_width = col.metadata().fixed_length; + int dispatch_const = + (rows->metadata().is_fixed_length ? 5 : 0) + + (col_width == 1 ? 0 + : col_width == 2 ? 1 : col_width == 4 ? 2 : col_width == 8 ? 3 : 4); + ColumnMemsetNullsImp_fn[dispatch_const](offset_within_row, rows, col, ctx, + temp_vector_16bit, byte_value); +} + +template +void KeyEncoder::EncoderBinary::ColumnMemsetNullsImp( + uint32_t offset_within_row, KeyRowArray* rows, const KeyColumnArray& col, + KeyEncoderContext* ctx, KeyColumnArray* temp_vector_16bit, uint8_t byte_value) { + // Nothing to do when there are no nulls + if (!col.data(0)) { + return; + } + + uint32_t num_rows = static_cast(col.length()); + + // Temp vector needs space for the required number of rows + ARROW_DCHECK(temp_vector_16bit->length() >= num_rows); + ARROW_DCHECK(temp_vector_16bit->metadata().is_fixed_length && + temp_vector_16bit->metadata().fixed_length == sizeof(uint16_t)); + uint16_t* temp_vector = reinterpret_cast(temp_vector_16bit->mutable_data(1)); + + // Bit vector to index vector of null positions + int num_selected; + util::BitUtil::bits_to_indexes(0, ctx->hardware_flags, static_cast(col.length()), + col.data(0), &num_selected, temp_vector); + + for (int i = 0; i < num_selected; ++i) { + uint32_t row_id = temp_vector[i]; + + // Target binary field pointer + uint8_t* dst; + if (is_row_fixed_length) { + dst = rows->mutable_data(1) + rows->metadata().fixed_length * row_id; + } else { + dst = rows->mutable_data(2) + rows->offsets()[row_id]; + } + dst += offset_within_row; + + if (col_width == 1) { + *dst = byte_value; + } else if (col_width == 2) { + *reinterpret_cast(dst) = + (static_cast(byte_value) * static_cast(0x0101)); + } else if (col_width == 4) { + *reinterpret_cast(dst) = + (static_cast(byte_value) * static_cast(0x01010101)); + } else if (col_width == 8) { + *reinterpret_cast(dst) = + (static_cast(byte_value) * 0x0101010101010101ULL); + } else { + uint64_t value = (static_cast(byte_value) * 0x0101010101010101ULL); + uint32_t col_width_actual = col.metadata().fixed_length; + uint32_t j; + for (j = 0; j < col_width_actual / 8; ++j) { + reinterpret_cast(dst)[j] = value; + } + int tail = col_width_actual % 8; + if (tail) { + uint64_t mask = ~0ULL >> (8 * (8 - tail)); + reinterpret_cast(dst)[j] = + (reinterpret_cast(dst)[j] & ~mask) | (value & mask); + } + } + } +} + +void KeyEncoder::EncoderBinaryPair::Encode(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col1, + const KeyColumnArray& col2, + KeyEncoderContext* ctx, KeyColumnArray* temp1, + KeyColumnArray* temp2) { + ARROW_DCHECK(CanProcessPair(col1.metadata(), col2.metadata())); + + KeyColumnArray col_prep[2]; + if (EncoderInteger::UsesTransform(col1)) { + col_prep[0] = EncoderInteger::ArrayReplace(col1, *temp1); + EncoderInteger::PreEncode(col1, &(col_prep[0]), ctx); + } else { + col_prep[0] = col1; + } + if (EncoderInteger::UsesTransform(col2)) { + col_prep[1] = EncoderInteger::ArrayReplace(col2, *temp2); + EncoderInteger::PreEncode(col2, &(col_prep[1]), ctx); + } else { + col_prep[1] = col2; + } + + uint32_t col_width1 = col_prep[0].metadata().fixed_length; + uint32_t col_width2 = col_prep[1].metadata().fixed_length; + int log_col_width1 = + col_width1 == 8 ? 3 : col_width1 == 4 ? 2 : col_width1 == 2 ? 1 : 0; + int log_col_width2 = + col_width2 == 8 ? 3 : col_width2 == 4 ? 2 : col_width2 == 2 ? 1 : 0; + + bool is_row_fixed_length = rows->metadata().is_fixed_length; + + uint32_t num_rows = static_cast(col1.length()); + uint32_t num_processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2() && col_width1 == col_width2) { + num_processed = EncodeHelper_avx2(is_row_fixed_length, col_width1, offset_within_row, + rows, col_prep[0], col_prep[1]); + } +#endif + if (num_processed < num_rows) { + using EncodeImp_t = void (*)(uint32_t, uint32_t, KeyRowArray*, const KeyColumnArray&, + const KeyColumnArray&); + static const EncodeImp_t EncodeImp_fn[] = { + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp, + EncodeImp, EncodeImp}; + int dispatch_const = (log_col_width2 << 2) | log_col_width1; + dispatch_const += (is_row_fixed_length ? 16 : 0); + EncodeImp_fn[dispatch_const](num_processed, offset_within_row, rows, col_prep[0], + col_prep[1]); + } +} + +template +void KeyEncoder::EncoderBinaryPair::EncodeImp(uint32_t num_rows_to_skip, + uint32_t offset_within_row, + KeyRowArray* rows, + const KeyColumnArray& col1, + const KeyColumnArray& col2) { + const uint8_t* src_A = col1.data(1); + const uint8_t* src_B = col2.data(1); + + uint32_t num_rows = static_cast(col1.length()); + + uint32_t fixed_length = rows->metadata().fixed_length; + const uint32_t* offsets; + uint8_t* dst_base; + if (is_row_fixed_length) { + dst_base = rows->mutable_data(1) + offset_within_row; + offsets = nullptr; + } else { + dst_base = rows->mutable_data(2) + offset_within_row; + offsets = rows->offsets(); + } + + using col1_type_const = typename std::add_const::type; + using col2_type_const = typename std::add_const::type; + + if (is_row_fixed_length) { + uint8_t* dst = dst_base + num_rows_to_skip * fixed_length; + for (uint32_t i = num_rows_to_skip; i < num_rows; ++i) { + *reinterpret_cast(dst) = reinterpret_cast(src_A)[i]; + *reinterpret_cast(dst + sizeof(col1_type)) = + reinterpret_cast(src_B)[i]; + dst += fixed_length; + } + } else { + for (uint32_t i = num_rows_to_skip; i < num_rows; ++i) { + uint8_t* dst = dst_base + offsets[i]; + *reinterpret_cast(dst) = reinterpret_cast(src_A)[i]; + *reinterpret_cast(dst + sizeof(col1_type)) = + reinterpret_cast(src_B)[i]; + } + } +} + +void KeyEncoder::EncoderBinaryPair::Decode(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col1, + KeyColumnArray* col2, KeyEncoderContext* ctx, + KeyColumnArray* temp1, KeyColumnArray* temp2) { + ARROW_DCHECK(CanProcessPair(col1->metadata(), col2->metadata())); + + KeyColumnArray col_prep[2]; + if (EncoderInteger::UsesTransform(*col1)) { + col_prep[0] = EncoderInteger::ArrayReplace(*col1, *temp1); + } else { + col_prep[0] = *col1; + } + if (EncoderInteger::UsesTransform(*col2)) { + col_prep[1] = EncoderInteger::ArrayReplace(*col2, *temp2); + } else { + col_prep[1] = *col2; + } + + uint32_t col_width1 = col_prep[0].metadata().fixed_length; + uint32_t col_width2 = col_prep[1].metadata().fixed_length; + int log_col_width1 = + col_width1 == 8 ? 3 : col_width1 == 4 ? 2 : col_width1 == 2 ? 1 : 0; + int log_col_width2 = + col_width2 == 8 ? 3 : col_width2 == 4 ? 2 : col_width2 == 2 ? 1 : 0; + + bool is_row_fixed_length = rows.metadata().is_fixed_length; + + uint32_t num_processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2() && col_width1 == col_width2) { + num_processed = + DecodeHelper_avx2(is_row_fixed_length, col_width1, start_row, num_rows, + offset_within_row, rows, &col_prep[0], &col_prep[1]); + } +#endif + if (num_processed < num_rows) { + typedef void (*DecodeImp_t)(uint32_t, uint32_t, uint32_t, uint32_t, + const KeyRowArray&, KeyColumnArray*, KeyColumnArray*); + static const DecodeImp_t DecodeImp_fn[] = { + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp, + DecodeImp, DecodeImp}; + int dispatch_const = + (log_col_width2 << 2) | log_col_width1 | (is_row_fixed_length ? 16 : 0); + DecodeImp_fn[dispatch_const](num_processed, start_row, num_rows, offset_within_row, + rows, &(col_prep[0]), &(col_prep[1])); + } + + if (EncoderInteger::UsesTransform(*col1)) { + EncoderInteger::PostDecode(col_prep[0], col1, ctx); + } + if (EncoderInteger::UsesTransform(*col2)) { + EncoderInteger::PostDecode(col_prep[1], col2, ctx); + } +} + +template +void KeyEncoder::EncoderBinaryPair::DecodeImp(uint32_t num_rows_to_skip, + uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, + const KeyRowArray& rows, + KeyColumnArray* col1, + KeyColumnArray* col2) { + ARROW_DCHECK(rows.length() >= start_row + num_rows); + ARROW_DCHECK(col1->length() == num_rows && col2->length() == num_rows); + + uint8_t* dst_A = col1->mutable_data(1); + uint8_t* dst_B = col2->mutable_data(1); + + uint32_t fixed_length = rows.metadata().fixed_length; + const uint32_t* offsets; + const uint8_t* src_base; + if (is_row_fixed_length) { + src_base = rows.data(1) + fixed_length * start_row + offset_within_row; + offsets = nullptr; + } else { + src_base = rows.data(2) + offset_within_row; + offsets = rows.offsets() + start_row; + } + + using col1_type_const = typename std::add_const::type; + using col2_type_const = typename std::add_const::type; + + if (is_row_fixed_length) { + const uint8_t* src = src_base + num_rows_to_skip * fixed_length; + for (uint32_t i = num_rows_to_skip; i < num_rows; ++i) { + reinterpret_cast(dst_A)[i] = *reinterpret_cast(src); + reinterpret_cast(dst_B)[i] = + *reinterpret_cast(src + sizeof(col1_type)); + src += fixed_length; + } + } else { + for (uint32_t i = num_rows_to_skip; i < num_rows; ++i) { + const uint8_t* src = src_base + offsets[i]; + reinterpret_cast(dst_A)[i] = *reinterpret_cast(src); + reinterpret_cast(dst_B)[i] = + *reinterpret_cast(src + sizeof(col1_type)); + } + } +} + +void KeyEncoder::EncoderOffsets::Encode(KeyRowArray* rows, + const std::vector& varbinary_cols, + KeyEncoderContext* ctx) { + ARROW_DCHECK(!varbinary_cols.empty()); + + // Rows and columns must all be varying-length + ARROW_DCHECK(!rows->metadata().is_fixed_length); + for (size_t col = 0; col < varbinary_cols.size(); ++col) { + ARROW_DCHECK(!varbinary_cols[col].metadata().is_fixed_length); + } + + uint32_t num_rows = static_cast(varbinary_cols[0].length()); + + // The space in columns must be exactly equal to a space for offsets in rows + ARROW_DCHECK(rows->length() == num_rows); + for (size_t col = 0; col < varbinary_cols.size(); ++col) { + ARROW_DCHECK(varbinary_cols[col].length() == num_rows); + } + + uint32_t num_processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2()) { + // Create a temp vector sized based on the number of columns + auto temp_buffer_holder = util::TempVectorHolder( + ctx->stack, static_cast(varbinary_cols.size()) * 8); + auto temp_buffer_32B_per_col = KeyColumnArray( + KeyColumnMetadata(true, sizeof(uint32_t)), varbinary_cols.size() * 8, nullptr, + reinterpret_cast(temp_buffer_holder.mutable_data()), nullptr); + + num_processed = EncodeImp_avx2(rows, varbinary_cols, &temp_buffer_32B_per_col); + } +#endif + if (num_processed < num_rows) { + EncodeImp(num_processed, rows, varbinary_cols); + } +} + +void KeyEncoder::EncoderOffsets::EncodeImp( + uint32_t num_rows_already_processed, KeyRowArray* rows, + const std::vector& varbinary_cols) { + ARROW_DCHECK(varbinary_cols.size() > 0); + + int row_alignment = rows->metadata().row_alignment; + int string_alignment = rows->metadata().string_alignment; + + uint32_t* row_offsets = rows->mutable_offsets(); + uint8_t* row_values = rows->mutable_data(2); + uint32_t num_rows = static_cast(varbinary_cols[0].length()); + + if (num_rows_already_processed == 0) { + row_offsets[0] = 0; + } + + uint32_t row_offset = row_offsets[num_rows_already_processed]; + for (uint32_t i = num_rows_already_processed; i < num_rows; ++i) { + uint32_t* varbinary_end = + rows->metadata().varbinary_end_array(row_values + row_offset); + + // Zero out lengths for nulls. + // Add lengths of all columns to get row size. + // Store varbinary field ends while summing their lengths. + + uint32_t offset_within_row = rows->metadata().fixed_length; + + for (size_t col = 0; col < varbinary_cols.size(); ++col) { + const uint32_t* col_offsets = varbinary_cols[col].offsets(); + uint32_t col_length = col_offsets[i + 1] - col_offsets[i]; + + const uint8_t* non_nulls = varbinary_cols[col].data(0); + if (non_nulls && BitUtil::GetBit(non_nulls, i) == 0) { + col_length = 0; + } + + offset_within_row += + KeyRowMetadata::padding_for_alignment(offset_within_row, string_alignment); + offset_within_row += col_length; + + varbinary_end[col] = offset_within_row; + } + + offset_within_row += + KeyRowMetadata::padding_for_alignment(offset_within_row, row_alignment); + row_offset += offset_within_row; + row_offsets[i + 1] = row_offset; + } +} + +void KeyEncoder::EncoderOffsets::Decode( + uint32_t start_row, uint32_t num_rows, const KeyRowArray& rows, + std::vector* varbinary_cols, + const std::vector& varbinary_cols_base_offset, KeyEncoderContext* ctx) { + ARROW_DCHECK(!varbinary_cols->empty()); + ARROW_DCHECK(varbinary_cols->size() == varbinary_cols_base_offset.size()); + + // Rows and columns must all be varying-length + ARROW_DCHECK(!rows.metadata().is_fixed_length); + for (size_t col = 0; col < varbinary_cols->size(); ++col) { + ARROW_DCHECK(!(*varbinary_cols)[col].metadata().is_fixed_length); + } + + // The space in columns must be exactly equal to a subset of rows selected + ARROW_DCHECK(rows.length() >= start_row + num_rows); + for (size_t col = 0; col < varbinary_cols->size(); ++col) { + ARROW_DCHECK((*varbinary_cols)[col].length() == num_rows); + } + + // Offsets of varbinary columns data within each encoded row are stored + // in the same encoded row as an array of 32-bit integers. + // This array follows immediately the data of fixed-length columns. + // There is one element for each varying-length column. + // The Nth element is the sum of all the lengths of varbinary columns data in + // that row, up to and including Nth varbinary column. + + const uint32_t* row_offsets = rows.offsets() + start_row; + + // Set the base offset for each column + for (size_t col = 0; col < varbinary_cols->size(); ++col) { + uint32_t* col_offsets = (*varbinary_cols)[col].mutable_offsets(); + col_offsets[0] = varbinary_cols_base_offset[col]; + } + + int string_alignment = rows.metadata().string_alignment; + + for (uint32_t i = 0; i < num_rows; ++i) { + // Find the beginning of cumulative lengths array for next row + const uint8_t* row = rows.data(2) + row_offsets[i]; + const uint32_t* varbinary_ends = rows.metadata().varbinary_end_array(row); + + // Update the offset of each column + uint32_t offset_within_row = rows.metadata().fixed_length; + for (size_t col = 0; col < varbinary_cols->size(); ++col) { + offset_within_row += + KeyRowMetadata::padding_for_alignment(offset_within_row, string_alignment); + uint32_t length = varbinary_ends[col] - offset_within_row; + offset_within_row = varbinary_ends[col]; + uint32_t* col_offsets = (*varbinary_cols)[col].mutable_offsets(); + col_offsets[i + 1] = col_offsets[i] + length; + } + } +} + +void KeyEncoder::EncoderVarBinary::Encode(uint32_t varbinary_col_id, KeyRowArray* rows, + const KeyColumnArray& col, + KeyEncoderContext* ctx) { +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2()) { + EncodeHelper_avx2(varbinary_col_id, rows, col); + } else { +#endif + if (varbinary_col_id == 0) { + EncodeImp(varbinary_col_id, rows, col); + } else { + EncodeImp(varbinary_col_id, rows, col); + } +#if defined(ARROW_HAVE_AVX2) + } +#endif +} + +void KeyEncoder::EncoderVarBinary::Decode(uint32_t start_row, uint32_t num_rows, + uint32_t varbinary_col_id, + const KeyRowArray& rows, KeyColumnArray* col, + KeyEncoderContext* ctx) { + // Output column varbinary buffer needs an extra 32B + // at the end in avx2 version and 8B otherwise. +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2()) { + DecodeHelper_avx2(start_row, num_rows, varbinary_col_id, rows, col); + } else { +#endif + if (varbinary_col_id == 0) { + DecodeImp(start_row, num_rows, varbinary_col_id, rows, col); + } else { + DecodeImp(start_row, num_rows, varbinary_col_id, rows, col); + } +#if defined(ARROW_HAVE_AVX2) + } +#endif +} + +template +void KeyEncoder::EncoderVarBinary::EncodeImp(uint32_t varbinary_col_id, KeyRowArray* rows, + const KeyColumnArray& col) { + EncodeDecodeHelper( + 0, static_cast(col.length()), varbinary_col_id, rows, rows, &col, nullptr, + [](uint8_t* dst, const uint8_t* src, int64_t length) { + uint64_t* dst64 = reinterpret_cast(dst); + const uint64_t* src64 = reinterpret_cast(src); + uint32_t istripe; + for (istripe = 0; istripe < length / 8; ++istripe) { + dst64[istripe] = util::SafeLoad(src64 + istripe); + } + if ((length % 8) > 0) { + uint64_t mask_last = ~0ULL >> (8 * (8 * (istripe + 1) - length)); + dst64[istripe] = (dst64[istripe] & ~mask_last) | + (util::SafeLoad(src64 + istripe) & mask_last); + } + }); +} + +template +void KeyEncoder::EncoderVarBinary::DecodeImp(uint32_t start_row, uint32_t num_rows, + uint32_t varbinary_col_id, + const KeyRowArray& rows, + KeyColumnArray* col) { + EncodeDecodeHelper( + start_row, num_rows, varbinary_col_id, &rows, nullptr, col, col, + [](uint8_t* dst, const uint8_t* src, int64_t length) { + for (uint32_t istripe = 0; istripe < (length + 7) / 8; ++istripe) { + uint64_t* dst64 = reinterpret_cast(dst); + const uint64_t* src64 = reinterpret_cast(src); + util::SafeStore(dst64 + istripe, src64[istripe]); + } + }); +} + +void KeyEncoder::EncoderNulls::Encode(KeyRowArray* rows, + const std::vector& cols, + KeyEncoderContext* ctx, + KeyColumnArray* temp_vector_16bit) { + ARROW_DCHECK(cols.size() > 0); + uint32_t num_rows = static_cast(rows->length()); + + // All input columns should have the same number of rows. + // They may or may not have non-nulls bit-vectors allocated. + for (size_t col = 0; col < cols.size(); ++col) { + ARROW_DCHECK(cols[col].length() == num_rows); + } + + // Temp vector needs space for the required number of rows + ARROW_DCHECK(temp_vector_16bit->length() >= num_rows); + ARROW_DCHECK(temp_vector_16bit->metadata().is_fixed_length && + temp_vector_16bit->metadata().fixed_length == sizeof(uint16_t)); + + uint8_t* null_masks = rows->null_masks(); + uint32_t null_masks_bytes_per_row = rows->metadata().null_masks_bytes_per_row; + memset(null_masks, 0, null_masks_bytes_per_row * num_rows); + for (size_t col = 0; col < cols.size(); ++col) { + const uint8_t* non_nulls = cols[col].data(0); + if (!non_nulls) { + continue; + } + int num_selected; + util::BitUtil::bits_to_indexes( + 0, ctx->hardware_flags, num_rows, non_nulls, &num_selected, + reinterpret_cast(temp_vector_16bit->mutable_data(1))); + for (int i = 0; i < num_selected; ++i) { + uint16_t row_id = reinterpret_cast(temp_vector_16bit->data(1))[i]; + int64_t null_masks_bit_id = row_id * null_masks_bytes_per_row * 8 + col; + BitUtil::SetBit(null_masks, null_masks_bit_id); + } + } +} + +void KeyEncoder::EncoderNulls::Decode(uint32_t start_row, uint32_t num_rows, + const KeyRowArray& rows, + std::vector* cols) { + // Every output column needs to have a space for exactly the required number + // of rows. It also needs to have non-nulls bit-vector allocated and mutable. + ARROW_DCHECK(cols->size() > 0); + for (size_t col = 0; col < cols->size(); ++col) { + ARROW_DCHECK((*cols)[col].length() == num_rows); + ARROW_DCHECK((*cols)[col].mutable_data(0)); + } + + const uint8_t* null_masks = rows.null_masks(); + uint32_t null_masks_bytes_per_row = rows.metadata().null_masks_bytes_per_row; + for (size_t col = 0; col < cols->size(); ++col) { + uint8_t* non_nulls = (*cols)[col].mutable_data(0); + memset(non_nulls, 0xff, BitUtil::BytesForBits(num_rows)); + for (uint32_t row = 0; row < num_rows; ++row) { + uint32_t null_masks_bit_id = + (start_row + row) * null_masks_bytes_per_row * 8 + static_cast(col); + bool is_set = BitUtil::GetBit(null_masks, null_masks_bit_id); + if (is_set) { + BitUtil::ClearBit(non_nulls, row); + } + } + } +} + +uint32_t KeyEncoder::KeyRowMetadata::num_varbinary_cols() const { + uint32_t result = 0; + for (size_t i = 0; i < column_metadatas.size(); ++i) { + if (!column_metadatas[i].is_fixed_length) { + ++result; + } + } + return result; +} + +bool KeyEncoder::KeyRowMetadata::is_compatible(const KeyRowMetadata& other) const { + if (other.num_cols() != num_cols()) { + return false; + } + if (row_alignment != other.row_alignment || + string_alignment != other.string_alignment) { + return false; + } + for (size_t i = 0; i < column_metadatas.size(); ++i) { + if (column_metadatas[i].is_fixed_length != + other.column_metadatas[i].is_fixed_length) { + return false; + } + if (column_metadatas[i].fixed_length != other.column_metadatas[i].fixed_length) { + return false; + } + } + return true; +} + +void KeyEncoder::KeyRowMetadata::FromColumnMetadataVector( + const std::vector& cols, int in_row_alignment, + int in_string_alignment) { + column_metadatas.resize(cols.size()); + for (size_t i = 0; i < cols.size(); ++i) { + column_metadatas[i] = cols[i]; + } + + uint32_t num_cols = static_cast(cols.size()); + + // Sort columns. + // Columns are sorted based on the size in bytes of their fixed-length part. + // For the varying-length column, the fixed-length part is the 32-bit field storing + // cumulative length of varying-length fields. + // The rules are: + // a) Boolean column, marked with fixed-length 0, is considered to have fixed-length + // part of 1 byte. b) Columns with fixed-length part being power of 2 or multiple of row + // alignment precede other columns. They are sorted among themselves based on size of + // fixed-length part. c) Fixed-length columns precede varying-length columns when both + // have the same size fixed-length part. + column_order.resize(num_cols); + for (uint32_t i = 0; i < num_cols; ++i) { + column_order[i] = i; + } + std::sort( + column_order.begin(), column_order.end(), [&cols](uint32_t left, uint32_t right) { + bool is_left_pow2 = + !cols[left].is_fixed_length || ARROW_POPCOUNT64(cols[left].fixed_length) <= 1; + bool is_right_pow2 = !cols[right].is_fixed_length || + ARROW_POPCOUNT64(cols[right].fixed_length) <= 1; + bool is_left_fixedlen = cols[left].is_fixed_length; + bool is_right_fixedlen = cols[right].is_fixed_length; + uint32_t width_left = + cols[left].is_fixed_length ? cols[left].fixed_length : sizeof(uint32_t); + uint32_t width_right = + cols[right].is_fixed_length ? cols[right].fixed_length : sizeof(uint32_t); + if (is_left_pow2 != is_right_pow2) { + return is_left_pow2; + } + if (!is_left_pow2) { + return left < right; + } + if (width_left != width_right) { + return width_left > width_right; + } + if (is_left_fixedlen != is_right_fixedlen) { + return is_left_fixedlen; + } + return left < right; + }); + + row_alignment = in_row_alignment; + string_alignment = in_string_alignment; + varbinary_end_array_offset = 0; + + column_offsets.resize(num_cols); + uint32_t num_varbinary_cols = 0; + uint32_t offset_within_row = 0; + for (uint32_t i = 0; i < num_cols; ++i) { + const KeyColumnMetadata& col = cols[column_order[i]]; + offset_within_row += + KeyRowMetadata::padding_for_alignment(offset_within_row, string_alignment, col); + column_offsets[i] = offset_within_row; + if (!col.is_fixed_length) { + if (num_varbinary_cols == 0) { + varbinary_end_array_offset = offset_within_row; + } + ARROW_DCHECK(column_offsets[i] - varbinary_end_array_offset == + num_varbinary_cols * sizeof(uint32_t)); + ++num_varbinary_cols; + offset_within_row += sizeof(uint32_t); + } else { + // Boolean column is a bit-vector, which is indicated by + // setting fixed length in column metadata to zero. + // It will be stored as a byte in output row. + if (col.fixed_length == 0) { + offset_within_row += 1; + } else { + offset_within_row += col.fixed_length; + } + } + } + + is_fixed_length = (num_varbinary_cols == 0); + fixed_length = + offset_within_row + + KeyRowMetadata::padding_for_alignment( + offset_within_row, num_varbinary_cols == 0 ? row_alignment : string_alignment); + + // We set the number of bytes per row storing null masks of individual key columns + // to be a power of two. This is not required. It could be also set to the minimal + // number of bytes required for a given number of bits (one bit per column). + null_masks_bytes_per_row = 1; + while (static_cast(null_masks_bytes_per_row * 8) < num_cols) { + null_masks_bytes_per_row *= 2; + } +} + +void KeyEncoder::Init(const std::vector& cols, KeyEncoderContext* ctx, + int row_alignment, int string_alignment) { + ctx_ = ctx; + row_metadata_.FromColumnMetadataVector(cols, row_alignment, string_alignment); + uint32_t num_cols = row_metadata_.num_cols(); + uint32_t num_varbinary_cols = row_metadata_.num_varbinary_cols(); + batch_all_cols_.resize(num_cols); + batch_varbinary_cols_.resize(num_varbinary_cols); + batch_varbinary_cols_base_offsets_.resize(num_varbinary_cols); +} + +void KeyEncoder::PrepareKeyColumnArrays(int64_t start_row, int64_t num_rows, + const std::vector& cols_in) { + uint32_t num_cols = static_cast(cols_in.size()); + ARROW_DCHECK(batch_all_cols_.size() == num_cols); + + uint32_t num_varbinary_visited = 0; + for (uint32_t i = 0; i < num_cols; ++i) { + const KeyColumnArray& col = cols_in[row_metadata_.column_order[i]]; + KeyColumnArray col_window(col, start_row, num_rows); + batch_all_cols_[i] = col_window; + if (!col.metadata().is_fixed_length) { + ARROW_DCHECK(num_varbinary_visited < batch_varbinary_cols_.size()); + // If start row is zero, then base offset of varbinary column is also zero. + if (start_row == 0) { + batch_varbinary_cols_base_offsets_[num_varbinary_visited] = 0; + } else { + batch_varbinary_cols_base_offsets_[num_varbinary_visited] = + col.offsets()[start_row]; + } + batch_varbinary_cols_[num_varbinary_visited++] = col_window; + } + } +} + +Status KeyEncoder::PrepareOutputForEncode(int64_t start_row, int64_t num_rows, + KeyRowArray* rows, + const std::vector& all_cols) { + int64_t num_bytes_required = 0; + + int64_t fixed_part = row_metadata_.fixed_length * num_rows; + int64_t var_part = 0; + for (size_t i = 0; i < all_cols.size(); ++i) { + const KeyColumnArray& col = all_cols[i]; + if (!col.metadata().is_fixed_length) { + ARROW_DCHECK(col.length() >= start_row + num_rows); + const uint32_t* offsets = col.offsets(); + var_part += offsets[start_row + num_rows] - offsets[start_row]; + // Include maximum padding that can be added to align the start of varbinary fields. + var_part += num_rows * row_metadata_.string_alignment; + } + } + // Include maximum padding that can be added to align the start of the rows. + if (!row_metadata_.is_fixed_length) { + fixed_part += row_metadata_.row_alignment * num_rows; + } + num_bytes_required = fixed_part + var_part; + + rows->Clean(); + RETURN_NOT_OK(rows->AppendEmpty(static_cast(num_rows), + static_cast(num_bytes_required))); + + return Status::OK(); +} + +void KeyEncoder::Encode(int64_t start_row, int64_t num_rows, KeyRowArray* rows, + const std::vector& cols) { + // Prepare column array vectors + PrepareKeyColumnArrays(start_row, num_rows, cols); + + // Create two temp vectors with 16-bit elements + auto temp_buffer_holder_A = + util::TempVectorHolder(ctx_->stack, static_cast(num_rows)); + auto temp_buffer_A = KeyColumnArray( + KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr, + reinterpret_cast(temp_buffer_holder_A.mutable_data()), nullptr); + auto temp_buffer_holder_B = + util::TempVectorHolder(ctx_->stack, static_cast(num_rows)); + auto temp_buffer_B = KeyColumnArray( + KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr, + reinterpret_cast(temp_buffer_holder_B.mutable_data()), nullptr); + + bool is_row_fixed_length = row_metadata_.is_fixed_length; + if (!is_row_fixed_length) { + // This call will generate and fill in data for both: + // - offsets to the entire encoded arrays + // - offsets for individual varbinary fields within each row + EncoderOffsets::Encode(rows, batch_varbinary_cols_, ctx_); + + uint32_t num_varbinary_cols = static_cast(batch_varbinary_cols_.size()); + for (uint32_t i = 0; i < num_varbinary_cols; ++i) { + // Memcpy varbinary fields into precomputed in the previous step + // positions in the output row buffer. + EncoderVarBinary::Encode(i, rows, batch_varbinary_cols_[i], ctx_); + } + } + + // Process fixed length columns + uint32_t num_cols = static_cast(batch_all_cols_.size()); + for (uint32_t i = 0; i < num_cols;) { + if (!batch_all_cols_[i].metadata().is_fixed_length) { + i += 1; + continue; + } + bool can_process_pair = + (i + 1 < num_cols) && batch_all_cols_[i + 1].metadata().is_fixed_length && + EncoderBinaryPair::CanProcessPair(batch_all_cols_[i].metadata(), + batch_all_cols_[i + 1].metadata()); + if (!can_process_pair) { + EncoderBinary::Encode(row_metadata_.column_offsets[i], rows, batch_all_cols_[i], + ctx_, &temp_buffer_A); + i += 1; + } else { + EncoderBinaryPair::Encode(row_metadata_.column_offsets[i], rows, batch_all_cols_[i], + batch_all_cols_[i + 1], ctx_, &temp_buffer_A, + &temp_buffer_B); + i += 2; + } + } + + // Process nulls + EncoderNulls::Encode(rows, batch_all_cols_, ctx_, &temp_buffer_A); +} + +void KeyEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, + int64_t start_row_output, int64_t num_rows, + const KeyRowArray& rows, + std::vector* cols) { + // Prepare column array vectors + PrepareKeyColumnArrays(start_row_output, num_rows, *cols); + + // Create two temp vectors with 16-bit elements + auto temp_buffer_holder_A = + util::TempVectorHolder(ctx_->stack, static_cast(num_rows)); + auto temp_buffer_A = KeyColumnArray( + KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr, + reinterpret_cast(temp_buffer_holder_A.mutable_data()), nullptr); + auto temp_buffer_holder_B = + util::TempVectorHolder(ctx_->stack, static_cast(num_rows)); + auto temp_buffer_B = KeyColumnArray( + KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr, + reinterpret_cast(temp_buffer_holder_B.mutable_data()), nullptr); + + bool is_row_fixed_length = row_metadata_.is_fixed_length; + if (!is_row_fixed_length) { + EncoderOffsets::Decode(static_cast(start_row_input), + static_cast(num_rows), rows, &batch_varbinary_cols_, + batch_varbinary_cols_base_offsets_, ctx_); + } + + // Process fixed length columns + uint32_t num_cols = static_cast(batch_all_cols_.size()); + for (uint32_t i = 0; i < num_cols;) { + if (!batch_all_cols_[i].metadata().is_fixed_length) { + i += 1; + continue; + } + bool can_process_pair = + (i + 1 < num_cols) && batch_all_cols_[i + 1].metadata().is_fixed_length && + EncoderBinaryPair::CanProcessPair(batch_all_cols_[i].metadata(), + batch_all_cols_[i + 1].metadata()); + if (!can_process_pair) { + EncoderBinary::Decode(static_cast(start_row_input), + static_cast(num_rows), + row_metadata_.column_offsets[i], rows, &batch_all_cols_[i], + ctx_, &temp_buffer_A); + i += 1; + } else { + EncoderBinaryPair::Decode( + static_cast(start_row_input), static_cast(num_rows), + row_metadata_.column_offsets[i], rows, &batch_all_cols_[i], + &batch_all_cols_[i + 1], ctx_, &temp_buffer_A, &temp_buffer_B); + i += 2; + } + } + + // Process nulls + EncoderNulls::Decode(static_cast(start_row_input), + static_cast(num_rows), rows, &batch_all_cols_); +} + +void KeyEncoder::DecodeVaryingLengthBuffers(int64_t start_row_input, + int64_t start_row_output, int64_t num_rows, + const KeyRowArray& rows, + std::vector* cols) { + // Prepare column array vectors + PrepareKeyColumnArrays(start_row_output, num_rows, *cols); + + bool is_row_fixed_length = row_metadata_.is_fixed_length; + if (!is_row_fixed_length) { + uint32_t num_varbinary_cols = static_cast(batch_varbinary_cols_.size()); + for (uint32_t i = 0; i < num_varbinary_cols; ++i) { + // Memcpy varbinary fields into precomputed in the previous step + // positions in the output row buffer. + EncoderVarBinary::Decode(static_cast(start_row_input), + static_cast(num_rows), i, rows, + &batch_varbinary_cols_[i], ctx_); + } + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_encode.h b/cpp/src/arrow/compute/exec/key_encode.h new file mode 100644 index 00000000000..3f5ef365a08 --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_encode.h @@ -0,0 +1,627 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/compute/exec/util.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/bit_util.h" + +namespace arrow { +namespace compute { + +class KeyColumnMetadata; + +/// Converts between key representation as a collection of arrays for +/// individual columns and another representation as a single array of rows +/// combining data from all columns into one value. +/// This conversion is reversible. +/// Row-oriented storage is beneficial when there is a need for random access +/// of individual rows and at the same time all included columns are likely to +/// be accessed together, as in the case of hash table key. +class KeyEncoder { + public: + struct KeyEncoderContext { + bool has_avx2() const { + return (hardware_flags & arrow::internal::CpuInfo::AVX2) > 0; + } + int64_t hardware_flags; + util::TempVectorStack* stack; + }; + + /// Description of a storage format of a single key column as needed + /// for the purpose of row encoding. + struct KeyColumnMetadata { + KeyColumnMetadata() = default; + KeyColumnMetadata(bool is_fixed_length_in, uint32_t fixed_length_in) + : is_fixed_length(is_fixed_length_in), fixed_length(fixed_length_in) {} + /// Is column storing a varying-length binary, using offsets array + /// to find a beginning of a value, or is it a fixed-length binary. + bool is_fixed_length; + /// For a fixed-length binary column: number of bytes per value. + /// Zero has a special meaning, indicating a bit vector with one bit per value. + /// For a varying-length binary column: number of bytes per offset. + uint32_t fixed_length; + }; + + /// Description of a storage format for rows produced by encoder. + struct KeyRowMetadata { + /// Is row a varying-length binary, using offsets array to find a beginning of a row, + /// or is it a fixed-length binary. + bool is_fixed_length; + + /// For a fixed-length binary row, common size of rows in bytes, + /// rounded up to the multiple of alignment. + /// + /// For a varying-length binary, size of all encoded fixed-length key columns, + /// including lengths of varying-length columns, rounded up to the multiple of string + /// alignment. + uint32_t fixed_length; + + /// Offset within a row to the array of 32-bit offsets within a row of + /// ends of varbinary fields. + /// Used only when the row is not fixed-length, zero for fixed-length row. + /// There are N elements for N varbinary fields. + /// Each element is the offset within a row of the first byte after + /// the corresponding varbinary field bytes in that row. + /// If varbinary fields begin at aligned addresses, than the end of the previous + /// varbinary field needs to be rounded up according to the specified alignment + /// to obtain the beginning of the next varbinary field. + /// The first varbinary field starts at offset specified by fixed_length, + /// which should already be aligned. + uint32_t varbinary_end_array_offset; + + /// Fixed number of bytes per row that are used to encode null masks. + /// Null masks indicate for a single row which of its key columns are null. + /// Nth bit in the sequence of bytes assigned to a row represents null + /// information for Nth field according to the order in which they are encoded. + int null_masks_bytes_per_row; + + /// Power of 2. Every row will start at the offset aligned to that number of bytes. + int row_alignment; + + /// Power of 2. Must be no greater than row alignment. + /// Every non-power-of-2 binary field and every varbinary field bytes + /// will start aligned to that number of bytes. + int string_alignment; + + /// Metadata of encoded columns in their original order. + std::vector column_metadatas; + + /// Order in which fields are encoded. + std::vector column_order; + + /// Offsets within a row to fields in their encoding order. + std::vector column_offsets; + + /// Rounding up offset to the nearest multiple of alignment value. + /// Alignment must be a power of 2. + static inline uint32_t padding_for_alignment(uint32_t offset, + int required_alignment) { + ARROW_DCHECK(ARROW_POPCOUNT64(required_alignment) == 1); + return static_cast((-static_cast(offset)) & + (required_alignment - 1)); + } + + /// Rounding up offset to the beginning of next column, + /// chosing required alignment based on the data type of that column. + static inline uint32_t padding_for_alignment(uint32_t offset, int string_alignment, + const KeyColumnMetadata& col_metadata) { + if (!col_metadata.is_fixed_length || + ARROW_POPCOUNT64(col_metadata.fixed_length) <= 1) { + return 0; + } else { + return padding_for_alignment(offset, string_alignment); + } + } + + /// Returns an array of offsets within a row of ends of varbinary fields. + inline const uint32_t* varbinary_end_array(const uint8_t* row) const { + ARROW_DCHECK(!is_fixed_length); + return reinterpret_cast(row + varbinary_end_array_offset); + } + inline uint32_t* varbinary_end_array(uint8_t* row) const { + ARROW_DCHECK(!is_fixed_length); + return reinterpret_cast(row + varbinary_end_array_offset); + } + + /// Returns the offset within the row and length of the first varbinary field. + inline void first_varbinary_offset_and_length(const uint8_t* row, uint32_t* offset, + uint32_t* length) const { + ARROW_DCHECK(!is_fixed_length); + *offset = fixed_length; + *length = varbinary_end_array(row)[0] - fixed_length; + } + + /// Returns the offset within the row and length of the second and further varbinary + /// fields. + inline void nth_varbinary_offset_and_length(const uint8_t* row, int varbinary_id, + uint32_t* out_offset, + uint32_t* out_length) const { + ARROW_DCHECK(!is_fixed_length); + ARROW_DCHECK(varbinary_id > 0); + const uint32_t* varbinary_end = varbinary_end_array(row); + uint32_t offset = varbinary_end[varbinary_id - 1]; + offset += padding_for_alignment(offset, string_alignment); + *out_offset = offset; + *out_length = varbinary_end[varbinary_id] - offset; + } + + uint32_t encoded_field_order(uint32_t icol) const { return column_order[icol]; } + + uint32_t encoded_field_offset(uint32_t icol) const { return column_offsets[icol]; } + + uint32_t num_cols() const { return static_cast(column_metadatas.size()); } + + uint32_t num_varbinary_cols() const; + + void FromColumnMetadataVector(const std::vector& cols, + int in_row_alignment, int in_string_alignment); + + bool is_compatible(const KeyRowMetadata& other) const; + }; + + class KeyRowArray { + public: + KeyRowArray(); + Status Init(MemoryPool* pool, const KeyRowMetadata& metadata); + void Clean(); + Status AppendEmpty(uint32_t num_rows_to_append, uint32_t num_extra_bytes_to_append); + Status AppendSelectionFrom(const KeyRowArray& from, uint32_t num_rows_to_append, + const uint16_t* source_row_ids); + const KeyRowMetadata& metadata() const { return metadata_; } + int64_t length() const { return num_rows_; } + const uint8_t* data(int i) const { + ARROW_DCHECK(i >= 0 && i <= max_buffers_); + return buffers_[i]; + } + uint8_t* mutable_data(int i) { + ARROW_DCHECK(i >= 0 && i <= max_buffers_); + return mutable_buffers_[i]; + } + const uint32_t* offsets() const { return reinterpret_cast(data(1)); } + uint32_t* mutable_offsets() { return reinterpret_cast(mutable_data(1)); } + const uint8_t* null_masks() const { return null_masks_->data(); } + uint8_t* null_masks() { return null_masks_->mutable_data(); } + + bool has_any_nulls(const KeyEncoderContext* ctx) const; + + private: + Status ResizeFixedLengthBuffers(int64_t num_extra_rows); + Status ResizeOptionalVaryingLengthBuffer(int64_t num_extra_bytes); + + int64_t size_null_masks(int64_t num_rows); + int64_t size_offsets(int64_t num_rows); + int64_t size_rows_fixed_length(int64_t num_rows); + int64_t size_rows_varying_length(int64_t num_bytes); + void update_buffer_pointers(); + + static constexpr int64_t padding_for_vectors = 64; + MemoryPool* pool_; + KeyRowMetadata metadata_; + /// Buffers can only expand during lifetime and never shrink. + std::unique_ptr null_masks_; + std::unique_ptr offsets_; + std::unique_ptr rows_; + static constexpr int max_buffers_ = 3; + const uint8_t* buffers_[max_buffers_]; + uint8_t* mutable_buffers_[max_buffers_]; + int64_t num_rows_; + int64_t rows_capacity_; + int64_t bytes_capacity_; + + // Mutable to allow lazy evaluation + mutable int64_t num_rows_for_has_any_nulls_; + mutable bool has_any_nulls_; + }; + + /// A lightweight description of an array representing one of key columns. + class KeyColumnArray { + public: + KeyColumnArray() = default; + /// Create as a mix of buffers according to the mask from two descriptions + /// (Nth bit is set to 0 if Nth buffer from the first input + /// should be used and is set to 1 otherwise). + /// Metadata is inherited from the first input. + KeyColumnArray(const KeyColumnMetadata& metadata, const KeyColumnArray& left, + const KeyColumnArray& right, int buffer_id_to_replace); + /// Create for reading + KeyColumnArray(const KeyColumnMetadata& metadata, int64_t length, + const uint8_t* buffer0, const uint8_t* buffer1, + const uint8_t* buffer2); + /// Create for writing + KeyColumnArray(const KeyColumnMetadata& metadata, int64_t length, uint8_t* buffer0, + uint8_t* buffer1, uint8_t* buffer2); + /// Create as a window view of original description that is offset + /// by a given number of rows. + /// The number of rows used in offset must be divisible by 8 + /// in order to not split bit vectors within a single byte. + KeyColumnArray(const KeyColumnArray& from, int64_t start, int64_t length); + uint8_t* mutable_data(int i) { + ARROW_DCHECK(i >= 0 && i <= max_buffers_); + return mutable_buffers_[i]; + } + const uint8_t* data(int i) const { + ARROW_DCHECK(i >= 0 && i <= max_buffers_); + return buffers_[i]; + } + uint32_t* mutable_offsets() { return reinterpret_cast(mutable_data(1)); } + const uint32_t* offsets() const { return reinterpret_cast(data(1)); } + const KeyColumnMetadata& metadata() const { return metadata_; } + int64_t length() const { return length_; } + + private: + static constexpr int max_buffers_ = 3; + const uint8_t* buffers_[max_buffers_]; + uint8_t* mutable_buffers_[max_buffers_]; + KeyColumnMetadata metadata_; + int64_t length_; + }; + + void Init(const std::vector& cols, KeyEncoderContext* ctx, + int row_alignment, int string_alignment); + + const KeyRowMetadata& row_metadata() { return row_metadata_; } + + /// Find out the required sizes of all buffers output buffers for encoding + /// (including varying-length buffers). + /// Use that information to resize provided row array so that it can fit + /// encoded data. + Status PrepareOutputForEncode(int64_t start_input_row, int64_t num_input_rows, + KeyRowArray* rows, + const std::vector& all_cols); + + /// Encode a window of column oriented data into the entire output + /// row oriented storage. + /// The output buffers for encoding need to be correctly sized before + /// starting encoding. + void Encode(int64_t start_input_row, int64_t num_input_rows, KeyRowArray* rows, + const std::vector& cols); + + /// Decode a window of row oriented data into a corresponding + /// window of column oriented storage. + /// The output buffers need to be correctly allocated and sized before + /// calling each method. + /// For that reason decoding is split into two functions. + /// The output of the first one, that processes everything except for + /// varying length buffers, can be used to find out required varying + /// length buffers sizes. + void DecodeFixedLengthBuffers(int64_t start_row_input, int64_t start_row_output, + int64_t num_rows, const KeyRowArray& rows, + std::vector* cols); + + void DecodeVaryingLengthBuffers(int64_t start_row_input, int64_t start_row_output, + int64_t num_rows, const KeyRowArray& rows, + std::vector* cols); + + private: + /// Prepare column array vectors. + /// Output column arrays represent a range of input column arrays + /// specified by starting row and number of rows. + /// Three vectors are generated: + /// - all columns + /// - fixed-length columns only + /// - varying-length columns only + void PrepareKeyColumnArrays(int64_t start_row, int64_t num_rows, + const std::vector& cols_in); + + class TransformBoolean { + public: + static KeyColumnArray ArrayReplace(const KeyColumnArray& column, + const KeyColumnArray& temp); + static void PreEncode(const KeyColumnArray& input, KeyColumnArray* output, + KeyEncoderContext* ctx); + static void PostDecode(const KeyColumnArray& input, KeyColumnArray* output, + KeyEncoderContext* ctx); + }; + + class EncoderInteger { + public: + static void Encode(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col, KeyEncoderContext* ctx, + KeyColumnArray* temp); + static void Decode(uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col, + KeyEncoderContext* ctx, KeyColumnArray* temp); + static bool UsesTransform(const KeyColumnArray& column); + static KeyColumnArray ArrayReplace(const KeyColumnArray& column, + const KeyColumnArray& temp); + static void PreEncode(const KeyColumnArray& input, KeyColumnArray* output, + KeyEncoderContext* ctx); + static void PostDecode(const KeyColumnArray& input, KeyColumnArray* output, + KeyEncoderContext* ctx); + + private: + static bool IsBoolean(const KeyColumnMetadata& metadata); + }; + + class EncoderBinary { + public: + static void Encode(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col, KeyEncoderContext* ctx, + KeyColumnArray* temp); + static void Decode(uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col, + KeyEncoderContext* ctx, KeyColumnArray* temp); + static bool IsInteger(const KeyColumnMetadata& metadata); + + private: + template + static inline void EncodeDecodeHelper(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, + const KeyRowArray* rows_const, + KeyRowArray* rows_mutable_maybe_null, + const KeyColumnArray* col_const, + KeyColumnArray* col_mutable_maybe_null, + COPY_FN copy_fn); + template + static void EncodeImp(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col); + template + static void DecodeImp(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, const KeyRowArray& rows, + KeyColumnArray* col); +#if defined(ARROW_HAVE_AVX2) + static void EncodeHelper_avx2(bool is_row_fixed_length, uint32_t offset_within_row, + KeyRowArray* rows, const KeyColumnArray& col); + static void DecodeHelper_avx2(bool is_row_fixed_length, uint32_t start_row, + uint32_t num_rows, uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col); + template + static void EncodeImp_avx2(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col); + template + static void DecodeImp_avx2(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, const KeyRowArray& rows, + KeyColumnArray* col); +#endif + static void ColumnMemsetNulls(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col, KeyEncoderContext* ctx, + KeyColumnArray* temp_vector_16bit, uint8_t byte_value); + template + static void ColumnMemsetNullsImp(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col, KeyEncoderContext* ctx, + KeyColumnArray* temp_vector_16bit, + uint8_t byte_value); + }; + + class EncoderBinaryPair { + public: + static bool CanProcessPair(const KeyColumnMetadata& col1, + const KeyColumnMetadata& col2) { + return EncoderBinary::IsInteger(col1) && EncoderBinary::IsInteger(col2); + } + static void Encode(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col1, const KeyColumnArray& col2, + KeyEncoderContext* ctx, KeyColumnArray* temp1, + KeyColumnArray* temp2); + static void Decode(uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col1, + KeyColumnArray* col2, KeyEncoderContext* ctx, + KeyColumnArray* temp1, KeyColumnArray* temp2); + + private: + template + static void EncodeImp(uint32_t num_rows_to_skip, uint32_t offset_within_row, + KeyRowArray* rows, const KeyColumnArray& col1, + const KeyColumnArray& col2); + template + static void DecodeImp(uint32_t num_rows_to_skip, uint32_t start_row, + uint32_t num_rows, uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col1, + KeyColumnArray* col2); +#if defined(ARROW_HAVE_AVX2) + static uint32_t EncodeHelper_avx2(bool is_row_fixed_length, uint32_t col_width, + uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col1, + const KeyColumnArray& col2); + static uint32_t DecodeHelper_avx2(bool is_row_fixed_length, uint32_t col_width, + uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, const KeyRowArray& rows, + KeyColumnArray* col1, KeyColumnArray* col2); + template + static uint32_t EncodeImp_avx2(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col1, + const KeyColumnArray& col2); + template + static uint32_t DecodeImp_avx2(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, const KeyRowArray& rows, + KeyColumnArray* col1, KeyColumnArray* col2); +#endif + }; + + class EncoderOffsets { + public: + // In order not to repeat work twice, + // encoding combines in a single pass computing of: + // a) row offsets for varying-length rows + // b) within each new row, the cumulative length array + // of varying-length values within a row. + static void Encode(KeyRowArray* rows, + const std::vector& varbinary_cols, + KeyEncoderContext* ctx); + static void Decode(uint32_t start_row, uint32_t num_rows, const KeyRowArray& rows, + std::vector* varbinary_cols, + const std::vector& varbinary_cols_base_offset, + KeyEncoderContext* ctx); + + private: + static void EncodeImp(uint32_t num_rows_already_processed, KeyRowArray* rows, + const std::vector& varbinary_cols); +#if defined(ARROW_HAVE_AVX2) + static uint32_t EncodeImp_avx2(KeyRowArray* rows, + const std::vector& varbinary_cols, + KeyColumnArray* temp_buffer_32B_per_col); +#endif + }; + + class EncoderVarBinary { + public: + static void Encode(uint32_t varbinary_col_id, KeyRowArray* rows, + const KeyColumnArray& col, KeyEncoderContext* ctx); + static void Decode(uint32_t start_row, uint32_t num_rows, uint32_t varbinary_col_id, + const KeyRowArray& rows, KeyColumnArray* col, + KeyEncoderContext* ctx); + + private: + template + static inline void EncodeDecodeHelper(uint32_t start_row, uint32_t num_rows, + uint32_t varbinary_col_id, + const KeyRowArray* rows_const, + KeyRowArray* rows_mutable_maybe_null, + const KeyColumnArray* col_const, + KeyColumnArray* col_mutable_maybe_null, + COPY_FN copy_fn); + template + static void EncodeImp(uint32_t varbinary_col_id, KeyRowArray* rows, + const KeyColumnArray& col); + template + static void DecodeImp(uint32_t start_row, uint32_t num_rows, + uint32_t varbinary_col_id, const KeyRowArray& rows, + KeyColumnArray* col); +#if defined(ARROW_HAVE_AVX2) + static void EncodeHelper_avx2(uint32_t varbinary_col_id, KeyRowArray* rows, + const KeyColumnArray& col); + static void DecodeHelper_avx2(uint32_t start_row, uint32_t num_rows, + uint32_t varbinary_col_id, const KeyRowArray& rows, + KeyColumnArray* col); + template + static void EncodeImp_avx2(uint32_t varbinary_col_id, KeyRowArray* rows, + const KeyColumnArray& col); + template + static void DecodeImp_avx2(uint32_t start_row, uint32_t num_rows, + uint32_t varbinary_col_id, const KeyRowArray& rows, + KeyColumnArray* col); +#endif + }; + + class EncoderNulls { + public: + static void Encode(KeyRowArray* rows, const std::vector& cols, + KeyEncoderContext* ctx, KeyColumnArray* temp_vector_16bit); + static void Decode(uint32_t start_row, uint32_t num_rows, const KeyRowArray& rows, + std::vector* cols); + }; + + KeyEncoderContext* ctx_; + + // Data initialized once, based on data types of key columns + KeyRowMetadata row_metadata_; + + // Data initialized for each input batch. + // All elements are ordered according to the order of encoded fields in a row. + std::vector batch_all_cols_; + std::vector batch_varbinary_cols_; + std::vector batch_varbinary_cols_base_offsets_; +}; + +template +inline void KeyEncoder::EncoderBinary::EncodeDecodeHelper( + uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row, + const KeyRowArray* rows_const, KeyRowArray* rows_mutable_maybe_null, + const KeyColumnArray* col_const, KeyColumnArray* col_mutable_maybe_null, + COPY_FN copy_fn) { + ARROW_DCHECK(col_const && col_const->metadata().is_fixed_length); + uint32_t col_width = col_const->metadata().fixed_length; + + if (is_row_fixed_length) { + uint32_t row_width = rows_const->metadata().fixed_length; + for (uint32_t i = 0; i < num_rows; ++i) { + const uint8_t* src; + uint8_t* dst; + if (is_encoding) { + src = col_const->data(1) + col_width * i; + dst = rows_mutable_maybe_null->mutable_data(1) + row_width * (start_row + i) + + offset_within_row; + } else { + src = rows_const->data(1) + row_width * (start_row + i) + offset_within_row; + dst = col_mutable_maybe_null->mutable_data(1) + col_width * i; + } + copy_fn(dst, src, col_width); + } + } else { + const uint32_t* row_offsets = rows_const->offsets(); + for (uint32_t i = 0; i < num_rows; ++i) { + const uint8_t* src; + uint8_t* dst; + if (is_encoding) { + src = col_const->data(1) + col_width * i; + dst = rows_mutable_maybe_null->mutable_data(2) + row_offsets[start_row + i] + + offset_within_row; + } else { + src = rows_const->data(2) + row_offsets[start_row + i] + offset_within_row; + dst = col_mutable_maybe_null->mutable_data(1) + col_width * i; + } + copy_fn(dst, src, col_width); + } + } +} + +template +inline void KeyEncoder::EncoderVarBinary::EncodeDecodeHelper( + uint32_t start_row, uint32_t num_rows, uint32_t varbinary_col_id, + const KeyRowArray* rows_const, KeyRowArray* rows_mutable_maybe_null, + const KeyColumnArray* col_const, KeyColumnArray* col_mutable_maybe_null, + COPY_FN copy_fn) { + // Column and rows need to be varying length + ARROW_DCHECK(!rows_const->metadata().is_fixed_length && + !col_const->metadata().is_fixed_length); + + const uint32_t* row_offsets_for_batch = rows_const->offsets() + start_row; + const uint32_t* col_offsets = col_const->offsets(); + + uint32_t col_offset_next = col_offsets[0]; + for (uint32_t i = 0; i < num_rows; ++i) { + uint32_t col_offset = col_offset_next; + col_offset_next = col_offsets[i + 1]; + + uint32_t row_offset = row_offsets_for_batch[i]; + const uint8_t* row = rows_const->data(2) + row_offset; + + uint32_t offset_within_row; + uint32_t length; + if (first_varbinary_col) { + rows_const->metadata().first_varbinary_offset_and_length(row, &offset_within_row, + &length); + } else { + rows_const->metadata().nth_varbinary_offset_and_length(row, varbinary_col_id, + &offset_within_row, &length); + } + + row_offset += offset_within_row; + + const uint8_t* src; + uint8_t* dst; + if (is_encoding) { + src = col_const->data(2) + col_offset; + dst = rows_mutable_maybe_null->mutable_data(2) + row_offset; + } else { + src = rows_const->data(2) + row_offset; + dst = col_mutable_maybe_null->mutable_data(2) + col_offset; + } + copy_fn(dst, src, length); + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_encode_avx2.cc b/cpp/src/arrow/compute/exec/key_encode_avx2.cc new file mode 100644 index 00000000000..d875412cf88 --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_encode_avx2.cc @@ -0,0 +1,545 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec/key_encode.h" + +namespace arrow { +namespace compute { + +#if defined(ARROW_HAVE_AVX2) + +inline __m256i set_first_n_bytes_avx2(int n) { + constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; + constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL; + constexpr uint64_t kByteSequence16To23 = 0x1716151413121110ULL; + constexpr uint64_t kByteSequence24To31 = 0x1f1e1d1c1b1a1918ULL; + + return _mm256_cmpgt_epi8(_mm256_set1_epi8(n), + _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15, + kByteSequence16To23, kByteSequence24To31)); +} + +inline __m256i inclusive_prefix_sum_32bit_avx2(__m256i x) { + x = _mm256_add_epi32( + x, _mm256_permutevar8x32_epi32( + _mm256_andnot_si256(_mm256_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0xffffffff), x), + _mm256_setr_epi32(7, 0, 1, 2, 3, 4, 5, 6))); + x = _mm256_add_epi32( + x, _mm256_permute4x64_epi64( + _mm256_andnot_si256( + _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 0xffffffff, 0xffffffff), x), + 0x93)); // 0b10010011 + x = _mm256_add_epi32( + x, _mm256_permute4x64_epi64( + _mm256_andnot_si256( + _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 0xffffffff, 0xffffffff), x), + 0x4f)); // 0b01001111 + return x; +} + +void KeyEncoder::EncoderBinary::EncodeHelper_avx2(bool is_row_fixed_length, + uint32_t offset_within_row, + KeyRowArray* rows, + const KeyColumnArray& col) { + if (is_row_fixed_length) { + EncodeImp_avx2(offset_within_row, rows, col); + } else { + EncodeImp_avx2(offset_within_row, rows, col); + } +} + +template +void KeyEncoder::EncoderBinary::EncodeImp_avx2(uint32_t offset_within_row, + KeyRowArray* rows, + const KeyColumnArray& col) { + EncodeDecodeHelper( + 0, static_cast(col.length()), offset_within_row, rows, rows, &col, + nullptr, [](uint8_t* dst, const uint8_t* src, int64_t length) { + __m256i* dst256 = reinterpret_cast<__m256i*>(dst); + const __m256i* src256 = reinterpret_cast(src); + uint32_t istripe; + for (istripe = 0; istripe < length / 32; ++istripe) { + _mm256_storeu_si256(dst256 + istripe, _mm256_loadu_si256(src256 + istripe)); + } + if ((length % 32) > 0) { + __m256i mask = set_first_n_bytes_avx2(length % 32); + _mm256_storeu_si256( + dst256 + istripe, + _mm256_blendv_epi8(_mm256_loadu_si256(dst256 + istripe), + _mm256_loadu_si256(src256 + istripe), mask)); + } + }); +} + +void KeyEncoder::EncoderBinary::DecodeHelper_avx2(bool is_row_fixed_length, + uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, + const KeyRowArray& rows, + KeyColumnArray* col) { + if (is_row_fixed_length) { + DecodeImp_avx2(start_row, num_rows, offset_within_row, rows, col); + } else { + DecodeImp_avx2(start_row, num_rows, offset_within_row, rows, col); + } +} + +template +void KeyEncoder::EncoderBinary::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, + const KeyRowArray& rows, + KeyColumnArray* col) { + EncodeDecodeHelper( + start_row, num_rows, offset_within_row, &rows, nullptr, col, col, + [](uint8_t* dst, const uint8_t* src, int64_t length) { + for (uint32_t istripe = 0; istripe < (length + 31) / 32; ++istripe) { + __m256i* dst256 = reinterpret_cast<__m256i*>(dst); + const __m256i* src256 = reinterpret_cast(src); + _mm256_storeu_si256(dst256 + istripe, _mm256_loadu_si256(src256 + istripe)); + } + }); +} + +uint32_t KeyEncoder::EncoderBinaryPair::EncodeHelper_avx2( + bool is_row_fixed_length, uint32_t col_width, uint32_t offset_within_row, + KeyRowArray* rows, const KeyColumnArray& col1, const KeyColumnArray& col2) { + using EncodeImp_avx2_t = + uint32_t (*)(uint32_t, KeyRowArray*, const KeyColumnArray&, const KeyColumnArray&); + static const EncodeImp_avx2_t EncodeImp_avx2_fn[] = { + EncodeImp_avx2, EncodeImp_avx2, EncodeImp_avx2, + EncodeImp_avx2, EncodeImp_avx2, EncodeImp_avx2, + EncodeImp_avx2, EncodeImp_avx2, + }; + int log_col_width = col_width == 8 ? 3 : col_width == 4 ? 2 : col_width == 2 ? 1 : 0; + int dispatch_const = (is_row_fixed_length ? 4 : 0) + log_col_width; + return EncodeImp_avx2_fn[dispatch_const](offset_within_row, rows, col1, col2); +} + +template +uint32_t KeyEncoder::EncoderBinaryPair::EncodeImp_avx2(uint32_t offset_within_row, + KeyRowArray* rows, + const KeyColumnArray& col1, + const KeyColumnArray& col2) { + uint32_t num_rows = static_cast(col1.length()); + ARROW_DCHECK(col_width == 1 || col_width == 2 || col_width == 4 || col_width == 8); + + const uint8_t* col_vals_A = col1.data(1); + const uint8_t* col_vals_B = col2.data(1); + uint8_t* row_vals = is_row_fixed_length ? rows->mutable_data(1) : rows->mutable_data(2); + + constexpr int unroll = 32 / col_width; + + uint32_t num_processed = num_rows / unroll * unroll; + + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + __m256i col_A = _mm256_loadu_si256(reinterpret_cast(col_vals_A) + i); + __m256i col_B = _mm256_loadu_si256(reinterpret_cast(col_vals_B) + i); + __m256i r0, r1; + if (col_width == 1) { + // results in 16-bit outputs in the order: 0..7, 16..23 + r0 = _mm256_unpacklo_epi8(col_A, col_B); + // results in 16-bit outputs in the order: 8..15, 24..31 + r1 = _mm256_unpackhi_epi8(col_A, col_B); + } else if (col_width == 2) { + // results in 32-bit outputs in the order: 0..3, 8..11 + r0 = _mm256_unpacklo_epi16(col_A, col_B); + // results in 32-bit outputs in the order: 4..7, 12..15 + r1 = _mm256_unpackhi_epi16(col_A, col_B); + } else if (col_width == 4) { + // results in 64-bit outputs in the order: 0..1, 4..5 + r0 = _mm256_unpacklo_epi32(col_A, col_B); + // results in 64-bit outputs in the order: 2..3, 6..7 + r1 = _mm256_unpackhi_epi32(col_A, col_B); + } else if (col_width == 8) { + // results in 128-bit outputs in the order: 0, 2 + r0 = _mm256_unpacklo_epi64(col_A, col_B); + // results in 128-bit outputs in the order: 1, 3 + r1 = _mm256_unpackhi_epi64(col_A, col_B); + } + col_A = _mm256_permute2x128_si256(r0, r1, 0x20); + col_B = _mm256_permute2x128_si256(r0, r1, 0x31); + if (col_width == 8) { + __m128i *dst0, *dst1, *dst2, *dst3; + if (is_row_fixed_length) { + uint32_t fixed_length = rows->metadata().fixed_length; + uint8_t* dst = row_vals + offset_within_row + fixed_length * i * unroll; + dst0 = reinterpret_cast<__m128i*>(dst); + dst1 = reinterpret_cast<__m128i*>(dst + fixed_length); + dst2 = reinterpret_cast<__m128i*>(dst + fixed_length * 2); + dst3 = reinterpret_cast<__m128i*>(dst + fixed_length * 3); + } else { + const uint32_t* row_offsets = rows->offsets() + i * unroll; + uint8_t* dst = row_vals + offset_within_row; + dst0 = reinterpret_cast<__m128i*>(dst + row_offsets[0]); + dst1 = reinterpret_cast<__m128i*>(dst + row_offsets[1]); + dst2 = reinterpret_cast<__m128i*>(dst + row_offsets[2]); + dst3 = reinterpret_cast<__m128i*>(dst + row_offsets[3]); + } + _mm_storeu_si128(dst0, _mm256_castsi256_si128(r0)); + _mm_storeu_si128(dst1, _mm256_castsi256_si128(r1)); + _mm_storeu_si128(dst2, _mm256_extracti128_si256(r0, 1)); + _mm_storeu_si128(dst3, _mm256_extracti128_si256(r1, 1)); + + } else { + uint8_t buffer[64]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(buffer), col_A); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(buffer) + 1, col_B); + + if (is_row_fixed_length) { + uint32_t fixed_length = rows->metadata().fixed_length; + uint8_t* dst = row_vals + offset_within_row + fixed_length * i * unroll; + for (int j = 0; j < unroll; ++j) { + if (col_width == 1) { + *reinterpret_cast(dst + fixed_length * j) = + reinterpret_cast(buffer)[j]; + } else if (col_width == 2) { + *reinterpret_cast(dst + fixed_length * j) = + reinterpret_cast(buffer)[j]; + } else if (col_width == 4) { + *reinterpret_cast(dst + fixed_length * j) = + reinterpret_cast(buffer)[j]; + } + } + } else { + const uint32_t* row_offsets = rows->offsets() + i * unroll; + uint8_t* dst = row_vals + offset_within_row; + for (int j = 0; j < unroll; ++j) { + if (col_width == 1) { + *reinterpret_cast(dst + row_offsets[j]) = + reinterpret_cast(buffer)[j]; + } else if (col_width == 2) { + *reinterpret_cast(dst + row_offsets[j]) = + reinterpret_cast(buffer)[j]; + } else if (col_width == 4) { + *reinterpret_cast(dst + row_offsets[j]) = + reinterpret_cast(buffer)[j]; + } + } + } + } + } + + return num_processed; +} + +uint32_t KeyEncoder::EncoderBinaryPair::DecodeHelper_avx2( + bool is_row_fixed_length, uint32_t col_width, uint32_t start_row, uint32_t num_rows, + uint32_t offset_within_row, const KeyRowArray& rows, KeyColumnArray* col1, + KeyColumnArray* col2) { + using DecodeImp_avx2_t = + uint32_t (*)(uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col1, KeyColumnArray* col2); + static const DecodeImp_avx2_t DecodeImp_avx2_fn[] = { + DecodeImp_avx2, DecodeImp_avx2, DecodeImp_avx2, + DecodeImp_avx2, DecodeImp_avx2, DecodeImp_avx2, + DecodeImp_avx2, DecodeImp_avx2}; + int log_col_width = col_width == 8 ? 3 : col_width == 4 ? 2 : col_width == 2 ? 1 : 0; + int dispatch_const = log_col_width | (is_row_fixed_length ? 4 : 0); + return DecodeImp_avx2_fn[dispatch_const](start_row, num_rows, offset_within_row, rows, + col1, col2); +} + +template +uint32_t KeyEncoder::EncoderBinaryPair::DecodeImp_avx2( + uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row, + const KeyRowArray& rows, KeyColumnArray* col1, KeyColumnArray* col2) { + ARROW_DCHECK(col_width == 1 || col_width == 2 || col_width == 4 || col_width == 8); + + uint8_t* col_vals_A = col1->mutable_data(1); + uint8_t* col_vals_B = col2->mutable_data(1); + + uint32_t fixed_length = rows.metadata().fixed_length; + const uint32_t* offsets; + const uint8_t* src_base; + if (is_row_fixed_length) { + src_base = rows.data(1) + fixed_length * start_row + offset_within_row; + offsets = nullptr; + } else { + src_base = rows.data(2) + offset_within_row; + offsets = rows.offsets() + start_row; + } + + constexpr int unroll = 32 / col_width; + + uint32_t num_processed = num_rows / unroll * unroll; + + if (col_width == 8) { + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + const __m128i *src0, *src1, *src2, *src3; + if (is_row_fixed_length) { + const uint8_t* src = src_base + (i * unroll) * fixed_length; + src0 = reinterpret_cast(src); + src1 = reinterpret_cast(src + fixed_length); + src2 = reinterpret_cast(src + fixed_length * 2); + src3 = reinterpret_cast(src + fixed_length * 3); + } else { + const uint32_t* row_offsets = offsets + i * unroll; + const uint8_t* src = src_base; + src0 = reinterpret_cast(src + row_offsets[0]); + src1 = reinterpret_cast(src + row_offsets[1]); + src2 = reinterpret_cast(src + row_offsets[2]); + src3 = reinterpret_cast(src + row_offsets[3]); + } + + __m256i r0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_si128(src0)), + _mm_loadu_si128(src1), 1); + __m256i r1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_si128(src2)), + _mm_loadu_si128(src3), 1); + + r0 = _mm256_permute4x64_epi64(r0, 0xd8); // 0b11011000 + r1 = _mm256_permute4x64_epi64(r1, 0xd8); + + // First 128-bit lanes from both inputs + __m256i c1 = _mm256_permute2x128_si256(r0, r1, 0x20); + // Second 128-bit lanes from both inputs + __m256i c2 = _mm256_permute2x128_si256(r0, r1, 0x31); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(col_vals_A) + i, c1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(col_vals_B) + i, c2); + } + } else { + uint8_t buffer[64]; + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + if (is_row_fixed_length) { + const uint8_t* src = src_base + (i * unroll) * fixed_length; + for (int j = 0; j < unroll; ++j) { + if (col_width == 1) { + reinterpret_cast(buffer)[j] = + *reinterpret_cast(src + fixed_length * j); + } else if (col_width == 2) { + reinterpret_cast(buffer)[j] = + *reinterpret_cast(src + fixed_length * j); + } else if (col_width == 4) { + reinterpret_cast(buffer)[j] = + *reinterpret_cast(src + fixed_length * j); + } + } + } else { + const uint32_t* row_offsets = offsets + i * unroll; + const uint8_t* src = src_base; + for (int j = 0; j < unroll; ++j) { + if (col_width == 1) { + reinterpret_cast(buffer)[j] = + *reinterpret_cast(src + row_offsets[j]); + } else if (col_width == 2) { + reinterpret_cast(buffer)[j] = + *reinterpret_cast(src + row_offsets[j]); + } else if (col_width == 4) { + reinterpret_cast(buffer)[j] = + *reinterpret_cast(src + row_offsets[j]); + } + } + } + + __m256i r0 = _mm256_loadu_si256(reinterpret_cast(buffer)); + __m256i r1 = _mm256_loadu_si256(reinterpret_cast(buffer) + 1); + + constexpr uint64_t kByteSequence_0_2_4_6_8_10_12_14 = 0x0e0c0a0806040200ULL; + constexpr uint64_t kByteSequence_1_3_5_7_9_11_13_15 = 0x0f0d0b0907050301ULL; + constexpr uint64_t kByteSequence_0_1_4_5_8_9_12_13 = 0x0d0c090805040100ULL; + constexpr uint64_t kByteSequence_2_3_6_7_10_11_14_15 = 0x0f0e0b0a07060302ULL; + + if (col_width == 1) { + // Collect every second byte next to each other + const __m256i shuffle_const = _mm256_setr_epi64x( + kByteSequence_0_2_4_6_8_10_12_14, kByteSequence_1_3_5_7_9_11_13_15, + kByteSequence_0_2_4_6_8_10_12_14, kByteSequence_1_3_5_7_9_11_13_15); + r0 = _mm256_shuffle_epi8(r0, shuffle_const); + r1 = _mm256_shuffle_epi8(r1, shuffle_const); + // 0b11011000 swapping second and third 64-bit lane + r0 = _mm256_permute4x64_epi64(r0, 0xd8); + r1 = _mm256_permute4x64_epi64(r1, 0xd8); + } else if (col_width == 2) { + // Collect every second 16-bit word next to each other + const __m256i shuffle_const = _mm256_setr_epi64x( + kByteSequence_0_1_4_5_8_9_12_13, kByteSequence_2_3_6_7_10_11_14_15, + kByteSequence_0_1_4_5_8_9_12_13, kByteSequence_2_3_6_7_10_11_14_15); + r0 = _mm256_shuffle_epi8(r0, shuffle_const); + r1 = _mm256_shuffle_epi8(r1, shuffle_const); + // 0b11011000 swapping second and third 64-bit lane + r0 = _mm256_permute4x64_epi64(r0, 0xd8); + r1 = _mm256_permute4x64_epi64(r1, 0xd8); + } else if (col_width == 4) { + // Collect every second 32-bit word next to each other + const __m256i permute_const = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + r0 = _mm256_permutevar8x32_epi32(r0, permute_const); + r1 = _mm256_permutevar8x32_epi32(r1, permute_const); + } + + // First 128-bit lanes from both inputs + __m256i c1 = _mm256_permute2x128_si256(r0, r1, 0x20); + // Second 128-bit lanes from both inputs + __m256i c2 = _mm256_permute2x128_si256(r0, r1, 0x31); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(col_vals_A) + i, c1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(col_vals_B) + i, c2); + } + } + + return num_processed; +} + +uint32_t KeyEncoder::EncoderOffsets::EncodeImp_avx2( + KeyRowArray* rows, const std::vector& varbinary_cols, + KeyColumnArray* temp_buffer_32B_per_col) { + ARROW_DCHECK(temp_buffer_32B_per_col->metadata().is_fixed_length && + temp_buffer_32B_per_col->metadata().fixed_length == + static_cast(sizeof(uint32_t)) && + temp_buffer_32B_per_col->length() >= + static_cast(varbinary_cols.size()) * 8); + ARROW_DCHECK(varbinary_cols.size() > 0); + + int row_alignment = rows->metadata().row_alignment; + int string_alignment = rows->metadata().string_alignment; + + uint32_t* row_offsets = rows->mutable_offsets(); + uint8_t* row_values = rows->mutable_data(2); + uint32_t num_rows = static_cast(varbinary_cols[0].length()); + + constexpr int unroll = 8; + uint32_t num_processed = num_rows / unroll * unroll; + uint32_t* temp_varbinary_ends = + reinterpret_cast(temp_buffer_32B_per_col->mutable_data(1)); + + row_offsets[0] = 0; + + __m256i row_offset = _mm256_setzero_si256(); + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + // Zero out lengths for nulls. + // Add lengths of all columns to get row size. + // Store in temp buffer varbinary field ends while summing their lengths. + + __m256i offset_within_row = _mm256_set1_epi32(rows->metadata().fixed_length); + + for (size_t col = 0; col < varbinary_cols.size(); ++col) { + const uint32_t* col_offsets = varbinary_cols[col].offsets(); + __m256i col_length = _mm256_sub_epi32( + _mm256_loadu_si256(reinterpret_cast(col_offsets + 1) + i), + _mm256_loadu_si256(reinterpret_cast(col_offsets + 0) + i)); + + const uint8_t* non_nulls = varbinary_cols[col].data(0); + if (non_nulls && non_nulls[i] != 0xff) { + // Zero out lengths for values that are not null + const __m256i individual_bits = + _mm256_setr_epi32(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80); + __m256i null_mask = _mm256_cmpeq_epi32( + _mm256_setzero_si256(), + _mm256_and_si256(_mm256_set1_epi32(non_nulls[i]), individual_bits)); + col_length = _mm256_andnot_si256(null_mask, col_length); + } + + __m256i padding = + _mm256_and_si256(_mm256_sub_epi32(_mm256_setzero_si256(), offset_within_row), + _mm256_set1_epi32(string_alignment - 1)); + offset_within_row = _mm256_add_epi32(offset_within_row, padding); + offset_within_row = _mm256_add_epi32(offset_within_row, col_length); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp_varbinary_ends) + col, + offset_within_row); + } + + __m256i padding = + _mm256_and_si256(_mm256_sub_epi32(_mm256_setzero_si256(), offset_within_row), + _mm256_set1_epi32(row_alignment - 1)); + offset_within_row = _mm256_add_epi32(offset_within_row, padding); + + // Inclusive prefix sum of 32-bit elements + __m256i row_offset_delta = inclusive_prefix_sum_32bit_avx2(offset_within_row); + row_offset = _mm256_add_epi32( + _mm256_permutevar8x32_epi32(row_offset, _mm256_set1_epi32(7)), row_offset_delta); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(row_offsets + 1) + i, row_offset); + + // Output varbinary ends for all fields in each row + for (size_t col = 0; col < varbinary_cols.size(); ++col) { + for (uint32_t row = 0; row < unroll; ++row) { + uint32_t* dst = rows->metadata().varbinary_end_array( + row_values + row_offsets[i * unroll + row]) + + col; + const uint32_t* src = temp_varbinary_ends + (col * unroll + row); + *dst = *src; + } + } + } + + return num_processed; +} + +void KeyEncoder::EncoderVarBinary::EncodeHelper_avx2(uint32_t varbinary_col_id, + KeyRowArray* rows, + const KeyColumnArray& col) { + if (varbinary_col_id == 0) { + EncodeImp_avx2(varbinary_col_id, rows, col); + } else { + EncodeImp_avx2(varbinary_col_id, rows, col); + } +} + +template +void KeyEncoder::EncoderVarBinary::EncodeImp_avx2(uint32_t varbinary_col_id, + KeyRowArray* rows, + const KeyColumnArray& col) { + EncodeDecodeHelper( + 0, static_cast(col.length()), varbinary_col_id, rows, rows, &col, nullptr, + [](uint8_t* dst, const uint8_t* src, int64_t length) { + __m256i* dst256 = reinterpret_cast<__m256i*>(dst); + const __m256i* src256 = reinterpret_cast(src); + uint32_t istripe; + for (istripe = 0; istripe < length / 32; ++istripe) { + _mm256_storeu_si256(dst256 + istripe, _mm256_loadu_si256(src256 + istripe)); + } + if ((length % 32) > 0) { + __m256i mask = set_first_n_bytes_avx2(length % 32); + _mm256_storeu_si256( + dst256 + istripe, + _mm256_blendv_epi8(_mm256_loadu_si256(dst256 + istripe), + _mm256_loadu_si256(src256 + istripe), mask)); + } + }); +} + +void KeyEncoder::EncoderVarBinary::DecodeHelper_avx2(uint32_t start_row, + uint32_t num_rows, + uint32_t varbinary_col_id, + const KeyRowArray& rows, + KeyColumnArray* col) { + if (varbinary_col_id == 0) { + DecodeImp_avx2(start_row, num_rows, varbinary_col_id, rows, col); + } else { + DecodeImp_avx2(start_row, num_rows, varbinary_col_id, rows, col); + } +} + +template +void KeyEncoder::EncoderVarBinary::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows, + uint32_t varbinary_col_id, + const KeyRowArray& rows, + KeyColumnArray* col) { + EncodeDecodeHelper( + start_row, num_rows, varbinary_col_id, &rows, nullptr, col, col, + [](uint8_t* dst, const uint8_t* src, int64_t length) { + for (uint32_t istripe = 0; istripe < (length + 31) / 32; ++istripe) { + __m256i* dst256 = reinterpret_cast<__m256i*>(dst); + const __m256i* src256 = reinterpret_cast(src); + _mm256_storeu_si256(dst256 + istripe, _mm256_loadu_si256(src256 + istripe)); + } + }); +} + +#endif + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_hash.cc b/cpp/src/arrow/compute/exec/key_hash.cc new file mode 100644 index 00000000000..081411e708e --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_hash.cc @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/key_hash.h" + +#include + +#include +#include + +#include "arrow/compute/exec/util.h" + +namespace arrow { +namespace compute { + +inline uint32_t Hashing::avalanche_helper(uint32_t acc) { + acc ^= (acc >> 15); + acc *= PRIME32_2; + acc ^= (acc >> 13); + acc *= PRIME32_3; + acc ^= (acc >> 16); + return acc; +} + +void Hashing::avalanche(int64_t hardware_flags, uint32_t num_keys, uint32_t* hashes) { + uint32_t processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + int tail = num_keys % 8; + avalanche_avx2(num_keys - tail, hashes); + processed = num_keys - tail; + } +#endif + for (uint32_t i = processed; i < num_keys; ++i) { + hashes[i] = avalanche_helper(hashes[i]); + } +} + +inline uint32_t Hashing::combine_accumulators(const uint32_t acc1, const uint32_t acc2, + const uint32_t acc3, const uint32_t acc4) { + return ROTL(acc1, 1) + ROTL(acc2, 7) + ROTL(acc3, 12) + ROTL(acc4, 18); +} + +inline void Hashing::helper_8B(uint32_t key_length, uint32_t num_keys, + const uint8_t* keys, uint32_t* hashes) { + ARROW_DCHECK(key_length <= 8); + uint64_t mask = ~0ULL >> (8 * (8 - key_length)); + constexpr uint64_t multiplier = 14029467366897019727ULL; + uint32_t offset = 0; + for (uint32_t ikey = 0; ikey < num_keys; ++ikey) { + uint64_t x = *reinterpret_cast(keys + offset); + x &= mask; + hashes[ikey] = static_cast(BYTESWAP(x * multiplier)); + offset += key_length; + } +} + +inline void Hashing::helper_stripe(uint32_t offset, uint64_t mask_hi, const uint8_t* keys, + uint32_t& acc1, uint32_t& acc2, uint32_t& acc3, + uint32_t& acc4) { + uint64_t v1 = reinterpret_cast(keys + offset)[0]; + // We do not need to mask v1, because we will not process a stripe + // unless at least 9 bytes of it are part of the key. + uint64_t v2 = reinterpret_cast(keys + offset)[1]; + v2 &= mask_hi; + uint32_t x1 = static_cast(v1); + uint32_t x2 = static_cast(v1 >> 32); + uint32_t x3 = static_cast(v2); + uint32_t x4 = static_cast(v2 >> 32); + acc1 += x1 * PRIME32_2; + acc1 = ROTL(acc1, 13) * PRIME32_1; + acc2 += x2 * PRIME32_2; + acc2 = ROTL(acc2, 13) * PRIME32_1; + acc3 += x3 * PRIME32_2; + acc3 = ROTL(acc3, 13) * PRIME32_1; + acc4 += x4 * PRIME32_2; + acc4 = ROTL(acc4, 13) * PRIME32_1; +} + +void Hashing::helper_stripes(int64_t hardware_flags, uint32_t num_keys, + uint32_t key_length, const uint8_t* keys, uint32_t* hash) { + uint32_t processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + int tail = num_keys % 2; + helper_stripes_avx2(num_keys - tail, key_length, keys, hash); + processed = num_keys - tail; + } +#endif + + // If length modulo stripe length is less than or equal 8, round down to the nearest 16B + // boundary (8B ending will be processed in a separate function), otherwise round up. + const uint32_t num_stripes = (key_length + 7) / 16; + uint64_t mask_hi = + ~0ULL >> + (8 * ((num_stripes * 16 > key_length) ? num_stripes * 16 - key_length : 0)); + + for (uint32_t i = processed; i < num_keys; ++i) { + uint32_t acc1, acc2, acc3, acc4; + acc1 = static_cast( + (static_cast(PRIME32_1) + static_cast(PRIME32_2)) & + 0xffffffff); + acc2 = PRIME32_2; + acc3 = 0; + acc4 = static_cast(-static_cast(PRIME32_1)); + uint32_t offset = i * key_length; + for (uint32_t stripe = 0; stripe < num_stripes - 1; ++stripe) { + helper_stripe(offset, ~0ULL, keys, acc1, acc2, acc3, acc4); + offset += 16; + } + helper_stripe(offset, mask_hi, keys, acc1, acc2, acc3, acc4); + hash[i] = combine_accumulators(acc1, acc2, acc3, acc4); + } +} + +inline uint32_t Hashing::helper_tail(uint32_t offset, uint64_t mask, const uint8_t* keys, + uint32_t acc) { + uint64_t v = reinterpret_cast(keys + offset)[0]; + v &= mask; + uint32_t x1 = static_cast(v); + uint32_t x2 = static_cast(v >> 32); + acc += x1 * PRIME32_3; + acc = ROTL(acc, 17) * PRIME32_4; + acc += x2 * PRIME32_3; + acc = ROTL(acc, 17) * PRIME32_4; + return acc; +} + +void Hashing::helper_tails(int64_t hardware_flags, uint32_t num_keys, uint32_t key_length, + const uint8_t* keys, uint32_t* hash) { + uint32_t processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + int tail = num_keys % 8; + helper_tails_avx2(num_keys - tail, key_length, keys, hash); + processed = num_keys - tail; + } +#endif + uint64_t mask = ~0ULL >> (8 * (((key_length % 8) == 0) ? 0 : 8 - (key_length % 8))); + uint32_t offset = key_length / 16 * 16; + offset += processed * key_length; + for (uint32_t i = processed; i < num_keys; ++i) { + hash[i] = helper_tail(offset, mask, keys, hash[i]); + offset += key_length; + } +} + +void Hashing::hash_fixed(int64_t hardware_flags, uint32_t num_keys, uint32_t length_key, + const uint8_t* keys, uint32_t* hashes) { + ARROW_DCHECK(length_key > 0); + + if (length_key <= 8) { + helper_8B(length_key, num_keys, keys, hashes); + return; + } + helper_stripes(hardware_flags, num_keys, length_key, keys, hashes); + if ((length_key % 16) > 0 && (length_key % 16) <= 8) { + helper_tails(hardware_flags, num_keys, length_key, keys, hashes); + } + avalanche(hardware_flags, num_keys, hashes); +} + +void Hashing::hash_varlen_helper(uint32_t length, const uint8_t* key, uint32_t* acc) { + for (uint32_t i = 0; i < length / 16; ++i) { + for (int j = 0; j < 4; ++j) { + uint32_t lane = reinterpret_cast(key)[i * 4 + j]; + acc[j] += (lane * PRIME32_2); + acc[j] = ROTL(acc[j], 13); + acc[j] *= PRIME32_1; + } + } + + int tail = length % 16; + if (tail) { + uint64_t last_stripe[2]; + const uint64_t* last_stripe_base = + reinterpret_cast(key + length - (length % 16)); + last_stripe[0] = last_stripe_base[0]; + uint64_t mask = ~0ULL >> (8 * ((length + 7) / 8 * 8 - length)); + if (tail <= 8) { + last_stripe[1] = 0; + last_stripe[0] &= mask; + } else { + last_stripe[1] = last_stripe_base[1]; + last_stripe[1] &= mask; + } + for (int j = 0; j < 4; ++j) { + uint32_t lane = reinterpret_cast(last_stripe)[j]; + acc[j] += (lane * PRIME32_2); + acc[j] = ROTL(acc[j], 13); + acc[j] *= PRIME32_1; + } + } +} + +void Hashing::hash_varlen(int64_t hardware_flags, uint32_t num_rows, + const uint32_t* offsets, const uint8_t* concatenated_keys, + uint32_t* temp_buffer, // Needs to hold 4 x 32-bit per row + uint32_t* hashes) { +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + hash_varlen_avx2(num_rows, offsets, concatenated_keys, temp_buffer, hashes); + } else { +#endif + for (uint32_t i = 0; i < num_rows; ++i) { + uint32_t acc[4]; + acc[0] = static_cast( + (static_cast(PRIME32_1) + static_cast(PRIME32_2)) & + 0xffffffff); + acc[1] = PRIME32_2; + acc[2] = 0; + acc[3] = static_cast(-static_cast(PRIME32_1)); + uint32_t length = offsets[i + 1] - offsets[i]; + hash_varlen_helper(length, concatenated_keys + offsets[i], acc); + hashes[i] = combine_accumulators(acc[0], acc[1], acc[2], acc[3]); + } + avalanche(hardware_flags, num_rows, hashes); +#if defined(ARROW_HAVE_AVX2) + } +#endif +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_hash.h b/cpp/src/arrow/compute/exec/key_hash.h new file mode 100644 index 00000000000..7f8ab5185cc --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_hash.h @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#if defined(ARROW_HAVE_AVX2) +#include +#endif + +#include + +#include "arrow/compute/exec/util.h" + +namespace arrow { +namespace compute { + +// Implementations are based on xxh3 32-bit algorithm description from: +// https://github.com/Cyan4973/xxHash/blob/dev/doc/xxhash_spec.md +// +class Hashing { + public: + static void hash_fixed(int64_t hardware_flags, uint32_t num_keys, uint32_t length_key, + const uint8_t* keys, uint32_t* hashes); + + static void hash_varlen(int64_t hardware_flags, uint32_t num_rows, + const uint32_t* offsets, const uint8_t* concatenated_keys, + uint32_t* temp_buffer, // Needs to hold 4 x 32-bit per row + uint32_t* hashes); + + private: + static const uint32_t PRIME32_1 = 0x9E3779B1; // 0b10011110001101110111100110110001 + static const uint32_t PRIME32_2 = 0x85EBCA77; // 0b10000101111010111100101001110111 + static const uint32_t PRIME32_3 = 0xC2B2AE3D; // 0b11000010101100101010111000111101 + static const uint32_t PRIME32_4 = 0x27D4EB2F; // 0b00100111110101001110101100101111 + static const uint32_t PRIME32_5 = 0x165667B1; // 0b00010110010101100110011110110001 + + // Avalanche + static inline uint32_t avalanche_helper(uint32_t acc); +#if defined(ARROW_HAVE_AVX2) + static void avalanche_avx2(uint32_t num_keys, uint32_t* hashes); +#endif + static void avalanche(int64_t hardware_flags, uint32_t num_keys, uint32_t* hashes); + + // Accumulator combine + static inline uint32_t combine_accumulators(const uint32_t acc1, const uint32_t acc2, + const uint32_t acc3, const uint32_t acc4); +#if defined(ARROW_HAVE_AVX2) + static inline uint64_t combine_accumulators_avx2(__m256i acc); +#endif + + // Helpers + static inline void helper_8B(uint32_t key_length, uint32_t num_keys, + const uint8_t* keys, uint32_t* hashes); + static inline void helper_stripe(uint32_t offset, uint64_t mask_hi, const uint8_t* keys, + uint32_t& acc1, uint32_t& acc2, uint32_t& acc3, + uint32_t& acc4); + static inline uint32_t helper_tail(uint32_t offset, uint64_t mask, const uint8_t* keys, + uint32_t acc); +#if defined(ARROW_HAVE_AVX2) + static void helper_stripes_avx2(uint32_t num_keys, uint32_t key_length, + const uint8_t* keys, uint32_t* hash); + static void helper_tails_avx2(uint32_t num_keys, uint32_t key_length, + const uint8_t* keys, uint32_t* hash); +#endif + static void helper_stripes(int64_t hardware_flags, uint32_t num_keys, + uint32_t key_length, const uint8_t* keys, uint32_t* hash); + static void helper_tails(int64_t hardware_flags, uint32_t num_keys, uint32_t key_length, + const uint8_t* keys, uint32_t* hash); + + static void hash_varlen_helper(uint32_t length, const uint8_t* key, uint32_t* acc); +#if defined(ARROW_HAVE_AVX2) + static void hash_varlen_avx2(uint32_t num_rows, const uint32_t* offsets, + const uint8_t* concatenated_keys, + uint32_t* temp_buffer, // Needs to hold 4 x 32-bit per row + uint32_t* hashes); +#endif +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_hash_avx2.cc b/cpp/src/arrow/compute/exec/key_hash_avx2.cc new file mode 100644 index 00000000000..b58db015088 --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_hash_avx2.cc @@ -0,0 +1,248 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec/key_hash.h" + +namespace arrow { +namespace compute { + +#if defined(ARROW_HAVE_AVX2) + +void Hashing::avalanche_avx2(uint32_t num_keys, uint32_t* hashes) { + constexpr int unroll = 8; + ARROW_DCHECK(num_keys % unroll == 0); + for (uint32_t i = 0; i < num_keys / unroll; ++i) { + __m256i hash = _mm256_loadu_si256(reinterpret_cast(hashes) + i); + hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 15)); + hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_2)); + hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 13)); + hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_3)); + hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 16)); + _mm256_storeu_si256((reinterpret_cast<__m256i*>(hashes)) + i, hash); + } +} + +inline uint64_t Hashing::combine_accumulators_avx2(__m256i acc) { + acc = _mm256_or_si256( + _mm256_sllv_epi32(acc, _mm256_setr_epi32(1, 7, 12, 18, 1, 7, 12, 18)), + _mm256_srlv_epi32(acc, _mm256_setr_epi32(32 - 1, 32 - 7, 32 - 12, 32 - 18, 32 - 1, + 32 - 7, 32 - 12, 32 - 18))); + acc = _mm256_add_epi32(acc, _mm256_shuffle_epi32(acc, 0xee)); // 0b11101110 + acc = _mm256_add_epi32(acc, _mm256_srli_epi64(acc, 32)); + acc = _mm256_permutevar8x32_epi32(acc, _mm256_setr_epi32(0, 4, 0, 0, 0, 0, 0, 0)); + uint64_t result = _mm256_extract_epi64(acc, 0); + return result; +} + +void Hashing::helper_stripes_avx2(uint32_t num_keys, uint32_t key_length, + const uint8_t* keys, uint32_t* hash) { + constexpr int unroll = 2; + ARROW_DCHECK(num_keys % unroll == 0); + + constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; + constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL; + + const __m256i mask_last_stripe = + (key_length % 16) <= 8 + ? _mm256_set1_epi8(static_cast(0xffU)) + : _mm256_cmpgt_epi8(_mm256_set1_epi8(key_length % 16), + _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15, + kByteSequence0To7, kByteSequence8To15)); + + // If length modulo stripe length is less than or equal 8, round down to the nearest 16B + // boundary (8B ending will be processed in a separate function), otherwise round up. + const uint32_t num_stripes = (key_length + 7) / 16; + for (uint32_t i = 0; i < num_keys / unroll; ++i) { + __m256i acc = _mm256_setr_epi32( + static_cast((static_cast(PRIME32_1) + PRIME32_2) & + 0xffffffff), + PRIME32_2, 0, static_cast(-static_cast(PRIME32_1)), + static_cast((static_cast(PRIME32_1) + PRIME32_2) & + 0xffffffff), + PRIME32_2, 0, static_cast(-static_cast(PRIME32_1))); + auto key0 = reinterpret_cast(keys + key_length * 2 * i); + auto key1 = reinterpret_cast(keys + key_length * 2 * i + key_length); + for (uint32_t stripe = 0; stripe < num_stripes - 1; ++stripe) { + auto key_stripe = + _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_si128(key0 + stripe)), + _mm_loadu_si128(key1 + stripe), 1); + acc = _mm256_add_epi32( + acc, _mm256_mullo_epi32(key_stripe, _mm256_set1_epi32(PRIME32_2))); + acc = _mm256_or_si256(_mm256_slli_epi32(acc, 13), _mm256_srli_epi32(acc, 32 - 13)); + acc = _mm256_mullo_epi32(acc, _mm256_set1_epi32(PRIME32_1)); + } + auto key_stripe = _mm256_inserti128_si256( + _mm256_castsi128_si256(_mm_loadu_si128(key0 + num_stripes - 1)), + _mm_loadu_si128(key1 + num_stripes - 1), 1); + key_stripe = _mm256_and_si256(key_stripe, mask_last_stripe); + acc = _mm256_add_epi32(acc, + _mm256_mullo_epi32(key_stripe, _mm256_set1_epi32(PRIME32_2))); + acc = _mm256_or_si256(_mm256_slli_epi32(acc, 13), _mm256_srli_epi32(acc, 32 - 13)); + acc = _mm256_mullo_epi32(acc, _mm256_set1_epi32(PRIME32_1)); + uint64_t result = combine_accumulators_avx2(acc); + reinterpret_cast(hash)[i] = result; + } +} + +void Hashing::helper_tails_avx2(uint32_t num_keys, uint32_t key_length, + const uint8_t* keys, uint32_t* hash) { + constexpr int unroll = 8; + ARROW_DCHECK(num_keys % unroll == 0); + auto keys_i64 = reinterpret_cast(keys); + + // Process between 1 and 8 last bytes of each key, starting from 16B boundary. + // The caller needs to make sure that there are no more than 8 bytes to process after + // that 16B boundary. + uint32_t first_offset = key_length - (key_length % 16); + __m256i mask = _mm256_set1_epi64x((~0ULL) >> (8 * (8 - (key_length % 16)))); + __m256i offset = + _mm256_setr_epi32(0, key_length, key_length * 2, key_length * 3, key_length * 4, + key_length * 5, key_length * 6, key_length * 7); + offset = _mm256_add_epi32(offset, _mm256_set1_epi32(first_offset)); + __m256i offset_incr = _mm256_set1_epi32(key_length * 8); + + for (uint32_t i = 0; i < num_keys / unroll; ++i) { + auto v1 = _mm256_i32gather_epi64(keys_i64, _mm256_castsi256_si128(offset), 1); + auto v2 = _mm256_i32gather_epi64(keys_i64, _mm256_extracti128_si256(offset, 1), 1); + v1 = _mm256_and_si256(v1, mask); + v2 = _mm256_and_si256(v2, mask); + v1 = _mm256_permutevar8x32_epi32(v1, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + v2 = _mm256_permutevar8x32_epi32(v2, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + auto x1 = _mm256_permute2x128_si256(v1, v2, 0x20); + auto x2 = _mm256_permute2x128_si256(v1, v2, 0x31); + __m256i acc = _mm256_loadu_si256((reinterpret_cast(hash)) + i); + + acc = _mm256_add_epi32(acc, _mm256_mullo_epi32(x1, _mm256_set1_epi32(PRIME32_3))); + acc = _mm256_or_si256(_mm256_slli_epi32(acc, 17), _mm256_srli_epi32(acc, 32 - 17)); + acc = _mm256_mullo_epi32(acc, _mm256_set1_epi32(PRIME32_4)); + + acc = _mm256_add_epi32(acc, _mm256_mullo_epi32(x2, _mm256_set1_epi32(PRIME32_3))); + acc = _mm256_or_si256(_mm256_slli_epi32(acc, 17), _mm256_srli_epi32(acc, 32 - 17)); + acc = _mm256_mullo_epi32(acc, _mm256_set1_epi32(PRIME32_4)); + + _mm256_storeu_si256((reinterpret_cast<__m256i*>(hash)) + i, acc); + + offset = _mm256_add_epi32(offset, offset_incr); + } +} + +void Hashing::hash_varlen_avx2(uint32_t num_rows, const uint32_t* offsets, + const uint8_t* concatenated_keys, + uint32_t* temp_buffer, // Needs to hold 4 x 32-bit per row + uint32_t* hashes) { + constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; + constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL; + + const __m128i sequence = _mm_set_epi64x(kByteSequence8To15, kByteSequence0To7); + const __m128i acc_init = _mm_setr_epi32( + static_cast((static_cast(PRIME32_1) + PRIME32_2) & 0xffffffff), + PRIME32_2, 0, static_cast(-static_cast(PRIME32_1))); + + // Variable length keys are always processed as a sequence of 16B stripes, + // with the last stripe, if extending past the end of the key, having extra bytes set to + // 0 on the fly. + for (uint32_t ikey = 0; ikey < num_rows; ++ikey) { + uint32_t begin = offsets[ikey]; + uint32_t end = offsets[ikey + 1]; + uint32_t length = end - begin; + const uint8_t* base = concatenated_keys + begin; + + __m128i acc = acc_init; + + uint32_t i; + for (i = 0; i < (length - 1) / 16; ++i) { + __m128i key_stripe = _mm_loadu_si128(reinterpret_cast(base) + i); + acc = _mm_add_epi32(acc, _mm_mullo_epi32(key_stripe, _mm_set1_epi32(PRIME32_2))); + acc = _mm_or_si128(_mm_slli_epi32(acc, 13), _mm_srli_epi32(acc, 32 - 13)); + acc = _mm_mullo_epi32(acc, _mm_set1_epi32(PRIME32_1)); + } + __m128i key_stripe = _mm_loadu_si128(reinterpret_cast(base) + i); + __m128i mask = _mm_cmpgt_epi8(_mm_set1_epi8(((length - 1) % 16) + 1), sequence); + key_stripe = _mm_and_si128(key_stripe, mask); + acc = _mm_add_epi32(acc, _mm_mullo_epi32(key_stripe, _mm_set1_epi32(PRIME32_2))); + acc = _mm_or_si128(_mm_slli_epi32(acc, 13), _mm_srli_epi32(acc, 32 - 13)); + acc = _mm_mullo_epi32(acc, _mm_set1_epi32(PRIME32_1)); + + _mm_storeu_si128(reinterpret_cast<__m128i*>(temp_buffer) + ikey, acc); + } + + // Combine accumulators and perform avalanche + constexpr int unroll = 8; + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + __m256i accA = + _mm256_loadu_si256(reinterpret_cast(temp_buffer) + 4 * i + 0); + __m256i accB = + _mm256_loadu_si256(reinterpret_cast(temp_buffer) + 4 * i + 1); + __m256i accC = + _mm256_loadu_si256(reinterpret_cast(temp_buffer) + 4 * i + 2); + __m256i accD = + _mm256_loadu_si256(reinterpret_cast(temp_buffer) + 4 * i + 3); + // Transpose 2x 4x4 32-bit matrices + __m256i r0 = _mm256_unpacklo_epi32(accA, accB); + __m256i r1 = _mm256_unpackhi_epi32(accA, accB); + __m256i r2 = _mm256_unpacklo_epi32(accC, accD); + __m256i r3 = _mm256_unpackhi_epi32(accC, accD); + accA = _mm256_unpacklo_epi64(r0, r2); + accB = _mm256_unpackhi_epi64(r0, r2); + accC = _mm256_unpacklo_epi64(r1, r3); + accD = _mm256_unpackhi_epi64(r1, r3); + // _rotl(accA, 1) + // _rotl(accB, 7) + // _rotl(accC, 12) + // _rotl(accD, 18) + accA = _mm256_or_si256(_mm256_slli_epi32(accA, 1), _mm256_srli_epi32(accA, 32 - 1)); + accB = _mm256_or_si256(_mm256_slli_epi32(accB, 7), _mm256_srli_epi32(accB, 32 - 7)); + accC = _mm256_or_si256(_mm256_slli_epi32(accC, 12), _mm256_srli_epi32(accC, 32 - 12)); + accD = _mm256_or_si256(_mm256_slli_epi32(accD, 18), _mm256_srli_epi32(accD, 32 - 18)); + accA = _mm256_add_epi32(_mm256_add_epi32(accA, accB), _mm256_add_epi32(accC, accD)); + // avalanche + __m256i hash = accA; + hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 15)); + hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_2)); + hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 13)); + hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_3)); + hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 16)); + // Store. + // At this point, because of way 2x 4x4 transposition was done, output hashes are in + // order: 0, 2, 4, 6, 1, 3, 5, 7. Bring back the original order. + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(hashes) + i, + _mm256_permutevar8x32_epi32(hash, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7))); + } + // Process the tail of up to 7 hashes + for (uint32_t i = num_rows - num_rows % unroll; i < num_rows; ++i) { + uint32_t* temp_buffer_base = temp_buffer + i * 4; + uint32_t acc = ROTL(temp_buffer_base[0], 1) + ROTL(temp_buffer_base[1], 7) + + ROTL(temp_buffer_base[2], 12) + ROTL(temp_buffer_base[3], 18); + + // avalanche + acc ^= (acc >> 15); + acc *= PRIME32_2; + acc ^= (acc >> 13); + acc *= PRIME32_3; + acc ^= (acc >> 16); + + hashes[i] = acc; + } +} + +#endif + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_map.cc b/cpp/src/arrow/compute/exec/key_map.cc new file mode 100644 index 00000000000..c48487793e0 --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_map.cc @@ -0,0 +1,603 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/key_map.h" + +#include + +#include +#include + +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { + +using BitUtil::CountLeadingZeros; + +namespace compute { + +constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; + +// Search status bytes inside a block of 8 slots (64-bit word). +// Try to find a slot that contains a 7-bit stamp matching the one provided. +// There are three possible outcomes: +// 1. A matching slot is found. +// -> Return its index between 0 and 7 and set match found flag. +// 2. A matching slot is not found and there is an empty slot in the block. +// -> Return the index of the first empty slot and clear match found flag. +// 3. A matching slot is not found and there are no empty slots in the block. +// -> Return 8 as the output slot index and clear match found flag. +// +// Optionally an index of the first slot to start the search from can be specified. +// In this case slots before it will be ignored. +// +template +inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, + int* out_slot, int* out_match_found) { + // Filled slot bytes have the highest bit set to 0 and empty slots are equal to 0x80. + uint64_t block_high_bits = block & kHighBitOfEachByte; + + // Replicate 7-bit stamp to all non-empty slots, leaving zeroes for empty slots. + uint64_t stamp_pattern = stamp * ((block_high_bits ^ kHighBitOfEachByte) >> 7); + + // If we xor this pattern with block status bytes we get in individual bytes: + // a) 0x00, for filled slots matching the stamp, + // b) 0x00 < x < 0x80, for filled slots not matching the stamp, + // c) 0x80, for empty slots. + uint64_t block_xor_pattern = block ^ stamp_pattern; + + // If we then add 0x7f to every byte, we get: + // a) 0x7F + // b) 0x80 <= x < 0xFF + // c) 0xFF + uint64_t match_base = block_xor_pattern + ~kHighBitOfEachByte; + + // The highest bit now tells us if we have a match (0) or not (1). + // We will negate the bits so that match is represented by a set bit. + uint64_t matches = ~match_base; + + // Clear 7 non-relevant bits in each byte. + // Also clear bytes that correspond to slots that we were supposed to + // skip due to provided start slot index. + // Note: the highest byte corresponds to the first slot. + if (use_start_slot) { + matches &= kHighBitOfEachByte >> (8 * start_slot); + } else { + matches &= kHighBitOfEachByte; + } + + // We get 0 if there are no matches + *out_match_found = (matches == 0 ? 0 : 1); + + // Now if we or with the highest bits of the block and scan zero bits in reverse, + // we get 8x slot index that we were looking for. + // This formula works in all three cases a), b) and c). + *out_slot = static_cast(CountLeadingZeros(matches | block_high_bits) >> 3); +} + +// This call follows the call to search_block. +// The input slot index is the output returned by it, which is a value from 0 to 8, +// with 8 indicating that both: no match was found and there were no empty slots. +// +// If the slot corresponds to a non-empty slot return a group id associated with it. +// Otherwise return any group id from any of the slots or +// zero, which is the default value stored in empty slots. +// +inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, + uint64_t group_id_mask) { + // Input slot can be equal to 8, in which case we need to output any valid group id + // value, so we take the one from slot 0 in the block. + int clamped_slot = slot & 7; + + // Group id values for all 8 slots in the block are bit-packed and follow the status + // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In + // that case we can extract group id using aligned 64-bit word access. + int num_groupid_bits = static_cast(ARROW_POPCOUNT64(group_id_mask)); + ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || + num_groupid_bits == 32 || num_groupid_bits == 64); + + int bit_offset = clamped_slot * num_groupid_bits; + const uint64_t* group_id_bytes = + reinterpret_cast(block_ptr) + 1 + (bit_offset >> 6); + uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask; + + return group_id; +} + +// Return global slot id (the index including the information about the block) +// where the search should continue if the first comparison fails. +// This function always follows search_block and receives the slot id returned by it. +// +inline uint64_t SwissTable::next_slot_to_visit(uint64_t block_index, int slot, + int match_found) { + // The result should be taken modulo the number of all slots in all blocks, + // but here we allow it to take a value one above the last slot index. + // Modulo operation is postponed to later. + return block_index * 8 + slot + match_found; +} + +// Implements first (fast-path, optimistic) lookup. +// Searches for a match only within the start block and +// trying only the first slot with a matching stamp. +// +// Comparison callback needed for match verification is done outside of this function. +// Match bit vector filled by it only indicates finding a matching stamp in a slot. +// +template +void SwissTable::lookup_1(const uint16_t* selection, const int num_keys, + const uint32_t* hashes, uint8_t* out_match_bitvector, + uint32_t* out_groupids, uint32_t* out_slot_ids) { + // Clear the output bit vector + memset(out_match_bitvector, 0, (num_keys + 7) / 8); + + // Based on the size of the table, prepare bit number constants. + uint32_t stamp_mask = (1 << bits_stamp_) - 1; + int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + uint32_t groupid_mask = (1 << num_groupid_bits) - 1; + + for (int i = 0; i < num_keys; ++i) { + int id; + if (use_selection) { + id = selection[i]; + } else { + id = i; + } + + // Extract from hash: block index and stamp + // + uint32_t hash = hashes[id]; + uint32_t iblock = hash >> (bits_hash_ - bits_stamp_ - log_blocks_); + uint32_t stamp = iblock & stamp_mask; + iblock >>= bits_stamp_; + + uint32_t num_block_bytes = num_groupid_bits + 8; + const uint8_t* blockbase = reinterpret_cast(blocks_) + + static_cast(iblock) * num_block_bytes; + uint64_t block = *reinterpret_cast(blockbase); + + // Call helper functions to obtain the output triplet: + // - match (of a stamp) found flag + // - group id for key comparison + // - slot to resume search from in case of no match or false positive + int match_found; + int islot_in_block; + search_block(block, stamp, 0, &islot_in_block, &match_found); + uint64_t groupid = extract_group_id(blockbase, islot_in_block, groupid_mask); + ARROW_DCHECK(groupid < num_inserted_ || num_inserted_ == 0); + uint64_t islot = next_slot_to_visit(iblock, islot_in_block, match_found); + + out_match_bitvector[id / 8] |= match_found << (id & 7); + out_groupids[id] = static_cast(groupid); + out_slot_ids[id] = static_cast(islot); + } +} + +// How many groups we can keep in the hash table without the need for resizing. +// When we reach this limit, we need to break processing of any further rows and resize. +// +uint64_t SwissTable::num_groups_for_resize() const { + // Resize small hash tables when 50% full (up to 12KB). + // Resize large hash tables when 75% full. + constexpr int log_blocks_small_ = 9; + uint64_t num_slots = 1ULL << (log_blocks_ + 3); + if (log_blocks_ <= log_blocks_small_) { + return num_slots / 2; + } else { + return num_slots * 3 / 4; + } +} + +uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) { + uint64_t global_slot_id_mask = (1 << (log_blocks_ + 3)) - 1; + return global_slot_id & global_slot_id_mask; +} + +// Run a single round of slot search - comparison / insert - filter unprocessed. +// Update selection vector to reflect which items have been processed. +// Ids in selection vector do not have to be sorted. +// +Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected, + uint16_t* inout_selection, bool* out_need_resize, + uint32_t* out_group_ids, uint32_t* inout_next_slot_ids) { + auto num_groups_limit = num_groups_for_resize(); + ARROW_DCHECK(num_inserted_ < num_groups_limit); + + // Temporary arrays are of limited size. + // The input needs to be split into smaller portions if it exceeds that limit. + // + ARROW_DCHECK(*inout_num_selected <= static_cast(1 << log_minibatch_)); + + // We will split input row ids into three categories: + // - needing to visit next block [0] + // - needing comparison [1] + // - inserted [2] + // + auto ids_inserted_buf = + util::TempVectorHolder(temp_stack_, *inout_num_selected); + auto ids_for_comparison_buf = + util::TempVectorHolder(temp_stack_, *inout_num_selected); + constexpr int category_nomatch = 0; + constexpr int category_cmp = 1; + constexpr int category_inserted = 2; + int num_ids[3]; + num_ids[0] = num_ids[1] = num_ids[2] = 0; + uint16_t* ids[3]{inout_selection, ids_for_comparison_buf.mutable_data(), + ids_inserted_buf.mutable_data()}; + auto push_id = [&num_ids, &ids](int category, int id) { + ids[category][num_ids[category]++] = static_cast(id); + }; + + uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + uint64_t groupid_mask = (1ULL << num_groupid_bits) - 1; + constexpr uint64_t stamp_mask = 0x7f; + uint64_t num_block_bytes = (8 + num_groupid_bits); + + uint32_t num_processed; + for (num_processed = 0; + // Second condition in for loop: + // We need to break processing and have the caller of this function + // resize hash table if we reach the limit of the number of groups present. + num_processed < *inout_num_selected && + num_inserted_ + num_ids[category_inserted] < num_groups_limit; + ++num_processed) { + // row id in original batch + int id = inout_selection[num_processed]; + + uint64_t slot_id = wrap_global_slot_id(inout_next_slot_ids[id]); + uint64_t block_id = slot_id >> 3; + uint32_t hash = hashes[id]; + uint8_t* blockbase = blocks_ + num_block_bytes * block_id; + uint64_t block = *reinterpret_cast(blockbase); + uint64_t stamp = (hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask; + int start_slot = (slot_id & 7); + + bool isempty = (blockbase[7 - start_slot] == 0x80); + if (isempty) { + // If we reach the empty slot we insert key for new group + + blockbase[7 - start_slot] = static_cast(stamp); + uint32_t group_id = num_inserted_ + num_ids[category_inserted]; + int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); + + // We assume here that the number of bits is rounded up to 8, 16, 32 or 64. + // In that case we can insert group id value using aligned 64-bit word access. + ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || + num_groupid_bits == 32 || num_groupid_bits == 64); + reinterpret_cast(blockbase + 8)[groupid_bit_offset >> 6] |= + (static_cast(group_id) << (groupid_bit_offset & 63)); + + hashes_[slot_id] = hash; + out_group_ids[id] = group_id; + push_id(category_inserted, id); + } else { + // We search for a slot with a matching stamp within a single block. + // We append row id to the appropriate sequence of ids based on + // whether the match has been found or not. + + int new_match_found; + int new_slot; + search_block(block, static_cast(stamp), start_slot, &new_slot, + &new_match_found); + auto new_groupid = + static_cast(extract_group_id(blockbase, new_slot, groupid_mask)); + ARROW_DCHECK(new_groupid < num_inserted_ + num_ids[category_inserted]); + new_slot = + static_cast(next_slot_to_visit(block_id, new_slot, new_match_found)); + inout_next_slot_ids[id] = new_slot; + out_group_ids[id] = new_groupid; + push_id(new_match_found, id); + } + } + + // Copy keys for newly inserted rows using callback + RETURN_NOT_OK(append_impl_(num_ids[category_inserted], ids[category_inserted])); + num_inserted_ += num_ids[category_inserted]; + + // Evaluate comparisons and append ids of rows that failed it to the non-match set. + uint32_t num_not_equal; + equal_impl_(num_ids[category_cmp], ids[category_cmp], out_group_ids, &num_not_equal, + ids[category_nomatch] + num_ids[category_nomatch]); + num_ids[category_nomatch] += num_not_equal; + + // Append ids of any unprocessed entries if we aborted processing due to the need + // to resize. + if (num_processed < *inout_num_selected) { + memmove(ids[category_nomatch] + num_ids[category_nomatch], + inout_selection + num_processed, + sizeof(uint16_t) * (*inout_num_selected - num_processed)); + num_ids[category_nomatch] += (*inout_num_selected - num_processed); + } + + *out_need_resize = (num_inserted_ == num_groups_limit); + *inout_num_selected = num_ids[category_nomatch]; + return Status::OK(); +} + +// Use hashes and callbacks to find group ids for already existing keys and +// to insert and report newly assigned group ids for new keys. +// +Status SwissTable::map(const int num_keys, const uint32_t* hashes, + uint32_t* out_groupids) { + // Temporary buffers have limited size. + // Caller is responsible for splitting larger input arrays into smaller chunks. + ARROW_DCHECK(num_keys <= (1 << log_minibatch_)); + + // Allocate temporary buffers with a lifetime of this function + auto match_bitvector_buf = util::TempVectorHolder(temp_stack_, num_keys); + uint8_t* match_bitvector = match_bitvector_buf.mutable_data(); + auto slot_ids_buf = util::TempVectorHolder(temp_stack_, num_keys); + uint32_t* slot_ids = slot_ids_buf.mutable_data(); + auto ids_buf = util::TempVectorHolder(temp_stack_, num_keys); + uint16_t* ids = ids_buf.mutable_data(); + uint32_t num_ids; + + // First-pass processing. + // Optimistically use simplified lookup involving only a start block to find + // a single group id candidate for every input. +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) { + if (log_blocks_ <= 4) { + int tail = num_keys % 32; + int delta = num_keys - tail; + lookup_1_avx2_x32(num_keys - tail, hashes, match_bitvector, out_groupids, slot_ids); + lookup_1_avx2_x8(tail, hashes + delta, match_bitvector + delta / 8, + out_groupids + delta, slot_ids + delta); + } else { + lookup_1_avx2_x8(num_keys, hashes, match_bitvector, out_groupids, slot_ids); + } + } else { +#endif + lookup_1(nullptr, num_keys, hashes, match_bitvector, out_groupids, slot_ids); +#if defined(ARROW_HAVE_AVX2) + } +#endif + + int64_t num_matches = + arrow::internal::CountSetBits(match_bitvector, /*offset=*/0, num_keys); + + // After the first-pass processing count rows with matches (based on stamp comparison) + // and decide based on their percentage whether to call dense or sparse comparison + // function. Dense comparison means evaluating it for all inputs, even if the matching + // stamp was not found. It may be cheaper to evaluate comparison for all inputs if the + // extra cost of filtering is higher than the wasted processing of rows with no match. + // + // Dense comparison can only be used if there is at least one inserted key, + // because otherwise there is no key to compare to. + // + if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) { + // Dense comparisons + equal_impl_(num_keys, nullptr, out_groupids, &num_ids, ids); + } else { + // Sparse comparisons that involve filtering the input set of keys + auto ids_cmp_buf = util::TempVectorHolder(temp_stack_, num_keys); + uint16_t* ids_cmp = ids_cmp_buf.mutable_data(); + int num_ids_result; + util::BitUtil::bits_split_indexes(hardware_flags_, num_keys, match_bitvector, + &num_ids_result, ids, ids_cmp); + num_ids = num_ids_result; + uint32_t num_not_equal; + equal_impl_(num_keys - num_ids, ids_cmp, out_groupids, &num_not_equal, ids + num_ids); + num_ids += num_not_equal; + } + + do { + // A single round of slow-pass (robust) lookup or insert. + // A single round ends with either a single comparison verifying the match candidate + // or inserting a new key. A single round of slow-pass may return early if we reach + // the limit of the number of groups due to inserts of new keys. In that case we need + // to resize and recalculating starting global slot ids for new bigger hash table. + bool out_of_capacity; + RETURN_NOT_OK( + lookup_2(hashes, &num_ids, ids, &out_of_capacity, out_groupids, slot_ids)); + if (out_of_capacity) { + RETURN_NOT_OK(grow_double()); + // Reset start slot ids for still unprocessed input keys. + // + for (uint32_t i = 0; i < num_ids; ++i) { + // First slot in the new starting block + slot_ids[ids[i]] = (hashes[ids[i]] >> (bits_hash_ - log_blocks_)) * 8; + } + } + } while (num_ids > 0); + + return Status::OK(); +} + +Status SwissTable::grow_double() { + // Before and after metadata + int num_group_id_bits_before = num_groupid_bits_from_log_blocks(log_blocks_); + int num_group_id_bits_after = num_groupid_bits_from_log_blocks(log_blocks_ + 1); + uint64_t group_id_mask_before = ~0ULL >> (64 - num_group_id_bits_before); + int log_blocks_before = log_blocks_; + int log_blocks_after = log_blocks_ + 1; + uint64_t block_size_before = (8 + num_group_id_bits_before); + uint64_t block_size_after = (8 + num_group_id_bits_after); + uint64_t block_size_total_before = (block_size_before << log_blocks_before) + padding_; + uint64_t block_size_total_after = (block_size_after << log_blocks_after) + padding_; + uint64_t hashes_size_total_before = + (bits_hash_ / 8 * (1 << (log_blocks_before + 3))) + padding_; + uint64_t hashes_size_total_after = + (bits_hash_ / 8 * (1 << (log_blocks_after + 3))) + padding_; + constexpr uint32_t stamp_mask = (1 << bits_stamp_) - 1; + + // Allocate new buffers + uint8_t* blocks_new; + RETURN_NOT_OK(pool_->Allocate(block_size_total_after, &blocks_new)); + memset(blocks_new, 0, block_size_total_after); + uint8_t* hashes_new_8B; + uint32_t* hashes_new; + RETURN_NOT_OK(pool_->Allocate(hashes_size_total_after, &hashes_new_8B)); + hashes_new = reinterpret_cast(hashes_new_8B); + + // First pass over all old blocks. + // Reinsert entries that were not in the overflow block + // (block other than selected by hash bits corresponding to the entry). + for (int i = 0; i < (1 << log_blocks_); ++i) { + // How many full slots in this block + uint8_t* block_base = blocks_ + i * block_size_before; + uint8_t* double_block_base_new = blocks_new + 2 * i * block_size_after; + uint64_t block = *reinterpret_cast(block_base); + + auto full_slots = + static_cast(CountLeadingZeros(block & kHighBitOfEachByte) >> 3); + int full_slots_new[2]; + full_slots_new[0] = full_slots_new[1] = 0; + *reinterpret_cast(double_block_base_new) = kHighBitOfEachByte; + *reinterpret_cast(double_block_base_new + block_size_after) = + kHighBitOfEachByte; + + for (int j = 0; j < full_slots; ++j) { + uint64_t slot_id = i * 8 + j; + uint32_t hash = hashes_[slot_id]; + uint64_t block_id_new = hash >> (bits_hash_ - log_blocks_after); + bool is_overflow_entry = ((block_id_new >> 1) != static_cast(i)); + if (is_overflow_entry) { + continue; + } + + int ihalf = block_id_new & 1; + uint8_t stamp_new = + hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask; + uint64_t group_id_bit_offs = j * num_group_id_bits_before; + uint64_t group_id = (*reinterpret_cast(block_base + 8 + + (group_id_bit_offs >> 3)) >> + (group_id_bit_offs & 7)) & + group_id_mask_before; + + uint64_t slot_id_new = i * 16 + ihalf * 8 + full_slots_new[ihalf]; + hashes_new[slot_id_new] = hash; + uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after; + block_base_new[7 - full_slots_new[ihalf]] = stamp_new; + int group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after; + *reinterpret_cast(block_base_new + 8 + (group_id_bit_offs_new >> 3)) |= + (group_id << (group_id_bit_offs_new & 7)); + full_slots_new[ihalf]++; + } + } + + // Second pass over all old blocks. + // Reinsert entries that were in an overflow block. + for (int i = 0; i < (1 << log_blocks_); ++i) { + // How many full slots in this block + uint8_t* block_base = blocks_ + i * block_size_before; + uint64_t block = *reinterpret_cast(block_base); + int full_slots = static_cast(CountLeadingZeros(block & kHighBitOfEachByte) >> 3); + + for (int j = 0; j < full_slots; ++j) { + uint64_t slot_id = i * 8 + j; + uint32_t hash = hashes_[slot_id]; + uint64_t block_id_new = hash >> (bits_hash_ - log_blocks_after); + bool is_overflow_entry = ((block_id_new >> 1) != static_cast(i)); + if (!is_overflow_entry) { + continue; + } + + uint64_t group_id_bit_offs = j * num_group_id_bits_before; + uint64_t group_id = (*reinterpret_cast(block_base + 8 + + (group_id_bit_offs >> 3)) >> + (group_id_bit_offs & 7)) & + group_id_mask_before; + uint8_t stamp_new = + hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask; + + uint8_t* block_base_new = blocks_new + block_id_new * block_size_after; + uint64_t block_new = *reinterpret_cast(block_base_new); + int full_slots_new = + static_cast(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3); + while (full_slots_new == 8) { + block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1); + block_base_new = blocks_new + block_id_new * block_size_after; + block_new = *reinterpret_cast(block_base_new); + full_slots_new = + static_cast(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3); + } + + hashes_new[block_id_new * 8 + full_slots_new] = hash; + block_base_new[7 - full_slots_new] = stamp_new; + int group_id_bit_offs_new = full_slots_new * num_group_id_bits_after; + *reinterpret_cast(block_base_new + 8 + (group_id_bit_offs_new >> 3)) |= + (group_id << (group_id_bit_offs_new & 7)); + } + } + + pool_->Free(blocks_, block_size_total_before); + pool_->Free(reinterpret_cast(hashes_), hashes_size_total_before); + log_blocks_ = log_blocks_after; + blocks_ = blocks_new; + hashes_ = hashes_new; + + return Status::OK(); +} + +Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, + util::TempVectorStack* temp_stack, int log_minibatch, + EqualImpl equal_impl, AppendImpl append_impl) { + hardware_flags_ = hardware_flags; + pool_ = pool; + temp_stack_ = temp_stack; + log_minibatch_ = log_minibatch; + equal_impl_ = equal_impl; + append_impl_ = append_impl; + + log_blocks_ = 0; + int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + num_inserted_ = 0; + + const uint64_t block_bytes = 8 + num_groupid_bits; + const uint64_t slot_bytes = (block_bytes << log_blocks_) + padding_; + RETURN_NOT_OK(pool_->Allocate(slot_bytes, &blocks_)); + + // Make sure group ids are initially set to zero for all slots. + memset(blocks_, 0, slot_bytes); + + // Initialize all status bytes to represent an empty slot. + for (uint64_t i = 0; i < (static_cast(1) << log_blocks_); ++i) { + *reinterpret_cast(blocks_ + i * block_bytes) = kHighBitOfEachByte; + } + + uint64_t num_slots = 1ULL << (log_blocks_ + 3); + const uint64_t hash_size = sizeof(uint32_t); + const uint64_t hash_bytes = hash_size * num_slots + padding_; + uint8_t* hashes8; + RETURN_NOT_OK(pool_->Allocate(hash_bytes, &hashes8)); + hashes_ = reinterpret_cast(hashes8); + + return Status::OK(); +} + +void SwissTable::cleanup() { + if (blocks_) { + int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + const uint64_t block_bytes = 8 + num_groupid_bits; + const uint64_t slot_bytes = (block_bytes << log_blocks_) + padding_; + pool_->Free(blocks_, slot_bytes); + blocks_ = nullptr; + } + if (hashes_) { + uint64_t num_slots = 1ULL << (log_blocks_ + 3); + const uint64_t hash_size = sizeof(uint32_t); + const uint64_t hash_bytes = hash_size * num_slots + padding_; + pool_->Free(reinterpret_cast(hashes_), hash_bytes); + hashes_ = nullptr; + } + log_blocks_ = 0; + num_inserted_ = 0; +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_map.h b/cpp/src/arrow/compute/exec/key_map.h new file mode 100644 index 00000000000..8c472736ec4 --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_map.h @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/compute/exec/util.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { +namespace compute { + +class SwissTable { + public: + SwissTable() = default; + ~SwissTable() { cleanup(); } + + using EqualImpl = + std::function; + using AppendImpl = std::function; + + Status init(int64_t hardware_flags, MemoryPool* pool, util::TempVectorStack* temp_stack, + int log_minibatch, EqualImpl equal_impl, AppendImpl append_impl); + void cleanup(); + + Status map(const int ckeys, const uint32_t* hashes, uint32_t* outgroupids); + + private: + // Lookup helpers + + /// \brief Scan bytes in block in reverse and stop as soon + /// as a position of interest is found. + /// + /// Positions of interest: + /// a) slot with a matching stamp is encountered, + /// b) first empty slot is encountered, + /// c) we reach the end of the block. + /// + /// \param[in] block 8 byte block of hash table + /// \param[in] stamp 7 bits of hash used as a stamp + /// \param[in] start_slot Index of the first slot in the block to start search from. We + /// assume that this index always points to a non-empty slot, equivalently + /// that it comes before any empty slots. (Used only by one template + /// variant.) + /// \param[out] out_slot index corresponding to the discovered position of interest (8 + /// represents end of block). + /// \param[out] out_match_found an integer flag (0 or 1) indicating if we found a + /// matching stamp. + template + inline void search_block(uint64_t block, int stamp, int start_slot, int* out_slot, + int* out_match_found); + + /// \brief Extract group id for a given slot in a given block. + /// + /// Group ids follow in memory after 64-bit block data. + /// Maximum number of groups inserted is equal to the number + /// of all slots in all blocks, which is 8 * the number of blocks. + /// Group ids are bit packed using that maximum to determine the necessary number of + /// bits. + inline uint64_t extract_group_id(const uint8_t* block_ptr, int slot, + uint64_t group_id_mask); + + inline uint64_t next_slot_to_visit(uint64_t block_index, int slot, int match_found); + + inline void insert(uint8_t* block_base, uint64_t slot_id, uint32_t hash, uint8_t stamp, + uint32_t group_id); + + inline uint64_t num_groups_for_resize() const; + + inline uint64_t wrap_global_slot_id(uint64_t global_slot_id); + + // First hash table access + // Find first match in the start block if exists. + // Possible cases: + // 1. Stamp match in a block + // 2. No stamp match in a block, no empty buckets in a block + // 3. No stamp match in a block, empty buckets in a block + // + template + void lookup_1(const uint16_t* selection, const int num_keys, const uint32_t* hashes, + uint8_t* out_match_bitvector, uint32_t* out_group_ids, + uint32_t* out_slot_ids); +#if defined(ARROW_HAVE_AVX2) + void lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, uint32_t* out_group_ids, + uint32_t* out_next_slot_ids); + void lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, uint32_t* out_group_ids, + uint32_t* out_next_slot_ids); +#endif + + // Completing hash table lookup post first access + Status lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected, + uint16_t* inout_selection, bool* out_need_resize, + uint32_t* out_group_ids, uint32_t* out_next_slot_ids); + + // Resize small hash tables when 50% full (up to 8KB). + // Resize large hash tables when 75% full. + Status grow_double(); + + static int num_groupid_bits_from_log_blocks(int log_blocks) { + int required_bits = log_blocks + 3; + return required_bits <= 8 ? 8 + : required_bits <= 16 ? 16 : required_bits <= 32 ? 32 : 64; + } + + // Use 32-bit hash for now + static constexpr int bits_hash_ = 32; + + // Number of hash bits stored in slots in a block. + // The highest bits of hash determine block id. + // The next set of highest bits is a "stamp" stored in a slot in a block. + static constexpr int bits_stamp_ = 7; + + // Padding bytes added at the end of buffers for ease of SIMD access + static constexpr int padding_ = 64; + + int log_minibatch_; + // Base 2 log of the number of blocks + int log_blocks_ = 0; + // Number of keys inserted into hash table + uint32_t num_inserted_ = 0; + + // Data for blocks. + // Each block has 8 status bytes for 8 slots, followed by 8 bit packed group ids for + // these slots. In 8B status word, the order of bytes is reversed. Group ids are in + // normal order. There is 64B padding at the end. + // + // 0 byte - 7 bucket | 1. byte - 6 bucket | ... + // --------------------------------------------------- + // | Empty bit* | Empty bit | + // --------------------------------------------------- + // | 7-bit hash | 7-bit hash | + // --------------------------------------------------- + // * Empty bucket has value 0x80. Non-empty bucket has highest bit set to 0. + // + uint8_t* blocks_; + + // Array of hashes of values inserted into slots. + // Undefined if the corresponding slot is empty. + // There is 64B padding at the end. + uint32_t* hashes_; + + int64_t hardware_flags_; + MemoryPool* pool_; + util::TempVectorStack* temp_stack_; + + EqualImpl equal_impl_; + AppendImpl append_impl_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_map_avx2.cc b/cpp/src/arrow/compute/exec/key_map_avx2.cc new file mode 100644 index 00000000000..a2efb4d1bb9 --- /dev/null +++ b/cpp/src/arrow/compute/exec/key_map_avx2.cc @@ -0,0 +1,407 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec/key_map.h" + +namespace arrow { +namespace compute { + +#if defined(ARROW_HAVE_AVX2) + +// Why it is OK to round up number of rows internally: +// All of the buffers: hashes, out_match_bitvector, out_group_ids, out_next_slot_ids +// are temporary buffers of group id mapping. +// Temporary buffers are buffers that live only within the boundaries of a single +// minibatch. Temporary buffers add 64B at the end, so that SIMD code does not have to +// worry about reading and writing outside of the end of the buffer up to 64B. If the +// hashes array contains garbage after the last element, it cannot cause computation to +// fail, since any random data is a valid hash for the purpose of lookup. +// +// This is more or less translation of equivalent scalar code, adjusted for a different +// instruction set (e.g. missing leading zero count instruction). +// +void SwissTable::lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, uint32_t* out_group_ids, + uint32_t* out_next_slot_ids) { + // Number of inputs processed together in a loop + constexpr int unroll = 8; + + const int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_); + uint32_t group_id_mask = ~static_cast(0) >> (32 - num_group_id_bits); + const __m256i* vhash_ptr = reinterpret_cast(hashes); + const __m256i vstamp_mask = _mm256_set1_epi32((1 << bits_stamp_) - 1); + + // TODO: explain why it is ok to process hashes outside of buffer boundaries + for (int i = 0; i < ((num_hashes + unroll - 1) / unroll); ++i) { + constexpr uint64_t kEachByteIs8 = 0x0808080808080808ULL; + constexpr uint64_t kByteSequenceOfPowersOf2 = 0x8040201008040201ULL; + + // Calculate block index and hash stamp for a byte in a block + // + __m256i vhash = _mm256_loadu_si256(vhash_ptr + i); + __m256i vblock_id = _mm256_srlv_epi32( + vhash, _mm256_set1_epi32(bits_hash_ - bits_stamp_ - log_blocks_)); + __m256i vstamp = _mm256_and_si256(vblock_id, vstamp_mask); + vblock_id = _mm256_srli_epi32(vblock_id, bits_stamp_); + + // We now split inputs and process 4 at a time, + // in order to process 64-bit blocks + // + __m256i vblock_offset = + _mm256_mullo_epi32(vblock_id, _mm256_set1_epi32(num_group_id_bits + 8)); + __m256i voffset_A = _mm256_and_si256(vblock_offset, _mm256_set1_epi64x(0xffffffff)); + __m256i vstamp_A = _mm256_and_si256(vstamp, _mm256_set1_epi64x(0xffffffff)); + __m256i voffset_B = _mm256_srli_epi64(vblock_offset, 32); + __m256i vstamp_B = _mm256_srli_epi64(vstamp, 32); + + auto blocks_i64 = reinterpret_cast(blocks_); + auto vblock_A = _mm256_i64gather_epi64(blocks_i64, voffset_A, 1); + auto vblock_B = _mm256_i64gather_epi64(blocks_i64, voffset_B, 1); + __m256i vblock_highbits_A = + _mm256_cmpeq_epi8(vblock_A, _mm256_set1_epi8(static_cast(0x80))); + __m256i vblock_highbits_B = + _mm256_cmpeq_epi8(vblock_B, _mm256_set1_epi8(static_cast(0x80))); + __m256i vbyte_repeat_pattern = + _mm256_setr_epi64x(0ULL, kEachByteIs8, 0ULL, kEachByteIs8); + vstamp_A = _mm256_shuffle_epi8( + vstamp_A, _mm256_or_si256(vbyte_repeat_pattern, vblock_highbits_A)); + vstamp_B = _mm256_shuffle_epi8( + vstamp_B, _mm256_or_si256(vbyte_repeat_pattern, vblock_highbits_B)); + __m256i vmatches_A = _mm256_cmpeq_epi8(vblock_A, vstamp_A); + __m256i vmatches_B = _mm256_cmpeq_epi8(vblock_B, vstamp_B); + __m256i vmatch_found = _mm256_andnot_si256( + _mm256_blend_epi32(_mm256_cmpeq_epi64(vmatches_A, _mm256_setzero_si256()), + _mm256_cmpeq_epi64(vmatches_B, _mm256_setzero_si256()), + 0xaa), // 0b10101010 + _mm256_set1_epi8(static_cast(0xff))); + vmatches_A = + _mm256_sad_epu8(_mm256_and_si256(_mm256_or_si256(vmatches_A, vblock_highbits_A), + _mm256_set1_epi64x(kByteSequenceOfPowersOf2)), + _mm256_setzero_si256()); + vmatches_B = + _mm256_sad_epu8(_mm256_and_si256(_mm256_or_si256(vmatches_B, vblock_highbits_B), + _mm256_set1_epi64x(kByteSequenceOfPowersOf2)), + _mm256_setzero_si256()); + __m256i vmatches = _mm256_or_si256(vmatches_A, _mm256_slli_epi64(vmatches_B, 32)); + + // We are now back to processing 8 at a time. + // Each lane contains 8-bit bit vector marking slots that are matches. + // We need to find leading zeroes count for all slots. + // + // Emulating lzcnt in lowest bytes of 32-bit elements + __m256i vgt = _mm256_cmpgt_epi32(_mm256_set1_epi32(16), vmatches); + __m256i vnext_slot_id = + _mm256_blendv_epi8(_mm256_srli_epi32(vmatches, 4), + _mm256_and_si256(vmatches, _mm256_set1_epi32(0x0f)), vgt); + vnext_slot_id = _mm256_shuffle_epi8( + _mm256_setr_epi8(4, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 2, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0), + vnext_slot_id); + vnext_slot_id = + _mm256_add_epi32(_mm256_and_si256(vnext_slot_id, _mm256_set1_epi32(0xff)), + _mm256_and_si256(vgt, _mm256_set1_epi32(4))); + + // Lookup group ids + // + __m256i vgroupid_bit_offset = + _mm256_mullo_epi32(_mm256_and_si256(vnext_slot_id, _mm256_set1_epi32(7)), + _mm256_set1_epi32(num_group_id_bits)); + + // This only works for up to 25 bits per group id, since it uses 32-bit gather + // TODO: make sure this will never get called when there are more than 2^25 groups. + __m256i vgroupid = + _mm256_add_epi32(_mm256_srli_epi32(vgroupid_bit_offset, 3), + _mm256_add_epi32(vblock_offset, _mm256_set1_epi32(8))); + vgroupid = _mm256_i32gather_epi32(reinterpret_cast(blocks_), vgroupid, 1); + vgroupid = _mm256_srlv_epi32( + vgroupid, _mm256_and_si256(vgroupid_bit_offset, _mm256_set1_epi32(7))); + vgroupid = _mm256_and_si256(vgroupid, _mm256_set1_epi32(group_id_mask)); + + // Convert slot id relative to the block to slot id relative to the beginnning of the + // table + // + vnext_slot_id = _mm256_add_epi32( + _mm256_add_epi32(vnext_slot_id, + _mm256_and_si256(vmatch_found, _mm256_set1_epi32(1))), + _mm256_slli_epi32(vblock_id, 3)); + + // Convert match found vector from 32-bit elements to bit vector + out_match_bitvector[i] = _pext_u32(_mm256_movemask_epi8(vmatch_found), + 0x11111111); // 0b00010001 repeated 4x + _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, vgroupid); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_next_slot_ids) + i, vnext_slot_id); + } +} + +// Take a set of 16 64-bit elements, +// Output one AVX2 register per byte (0 to 7), containing a sequence of 16 bytes, +// one from each input 64-bit word, all from the same position in 64-bit word. +// 16 bytes are replicated in lower and upper half of each output register. +// +inline void split_bytes_avx2(__m256i word0, __m256i word1, __m256i word2, __m256i word3, + __m256i& byte0, __m256i& byte1, __m256i& byte2, + __m256i& byte3, __m256i& byte4, __m256i& byte5, + __m256i& byte6, __m256i& byte7) { + __m256i word01lo = _mm256_unpacklo_epi8( + word0, word1); // {a0, e0, a1, e1, ... a7, e7, c0, g0, c1, g1, ... c7, g7} + __m256i word23lo = _mm256_unpacklo_epi8( + word2, word3); // {i0, m0, i1, m1, ... i7, m7, k0, o0, k1, o1, ... k7, o7} + __m256i word01hi = _mm256_unpackhi_epi8( + word0, word1); // {b0, f0, b1, f1, ... b7, f1, d0, h0, d1, h1, ... d7, h7} + __m256i word23hi = _mm256_unpackhi_epi8( + word2, word3); // {j0, n0, j1, n1, ... j7, n7, l0, p0, l1, p1, ... l7, p7} + + __m256i a = + _mm256_unpacklo_epi16(word01lo, word01hi); // {a0, e0, b0, f0, ... a3, e3, b3, f3, + // c0, g0, d0, h0, ... c3, g3, d3, h3} + __m256i b = + _mm256_unpacklo_epi16(word23lo, word23hi); // {i0, m0, j0, n0, ... i3, m3, j3, n3, + // k0, o0, l0, p0, ... k3, o3, l3, p3} + __m256i c = + _mm256_unpackhi_epi16(word01lo, word01hi); // {a4, e4, b4, f4, ... a7, e7, b7, f7, + // c4, g4, d4, h4, ... c7, g7, d7, h7} + __m256i d = + _mm256_unpackhi_epi16(word23lo, word23hi); // {i4, m4, j4, n4, ... i7, m7, j7, n7, + // k4, o4, l4, p4, ... k7, o7, l7, p7} + + __m256i byte01 = _mm256_unpacklo_epi32( + a, b); // {a0, e0, b0, f0, i0, m0, j0, n0, a1, e1, b1, f1, i1, m1, j1, n1, c0, g0, + // d0, h0, k0, o0, l0, p0, ...} + __m256i shuffle_const = + _mm256_setr_epi8(0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15, 0, 2, 8, 10, + 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15); + byte01 = _mm256_permute4x64_epi64( + byte01, 0xd8); // 11011000 b - swapping middle two 64-bit elements + byte01 = _mm256_shuffle_epi8(byte01, shuffle_const); + __m256i byte23 = _mm256_unpackhi_epi32(a, b); + byte23 = _mm256_permute4x64_epi64(byte23, 0xd8); + byte23 = _mm256_shuffle_epi8(byte23, shuffle_const); + __m256i byte45 = _mm256_unpacklo_epi32(c, d); + byte45 = _mm256_permute4x64_epi64(byte45, 0xd8); + byte45 = _mm256_shuffle_epi8(byte45, shuffle_const); + __m256i byte67 = _mm256_unpackhi_epi32(c, d); + byte67 = _mm256_permute4x64_epi64(byte67, 0xd8); + byte67 = _mm256_shuffle_epi8(byte67, shuffle_const); + + byte0 = _mm256_permute4x64_epi64(byte01, 0x44); // 01000100 b + byte1 = _mm256_permute4x64_epi64(byte01, 0xee); // 11101110 b + byte2 = _mm256_permute4x64_epi64(byte23, 0x44); // 01000100 b + byte3 = _mm256_permute4x64_epi64(byte23, 0xee); // 11101110 b + byte4 = _mm256_permute4x64_epi64(byte45, 0x44); // 01000100 b + byte5 = _mm256_permute4x64_epi64(byte45, 0xee); // 11101110 b + byte6 = _mm256_permute4x64_epi64(byte67, 0x44); // 01000100 b + byte7 = _mm256_permute4x64_epi64(byte67, 0xee); // 11101110 b +} + +// This one can only process a multiple of 32 values. +// The caller needs to process the remaining tail, if the input is not divisible by 32, +// using a different method. +// TODO: Explain the idea behind storing arrays in SIMD registers. +// Explain why it is faster with SIMD than using memory loads. +void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, uint32_t* out_group_ids, + uint32_t* out_next_slot_ids) { + constexpr int unroll = 32; + + // There is a limit on the number of input blocks, + // because we want to store all their data in a set of AVX2 registers. + ARROW_DCHECK(log_blocks_ <= 4); + + // Remember that block bytes and group id bytes are in opposite orders in memory of hash + // table. We put them in the same order. + __m256i vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4, + vblock_byte5, vblock_byte6, vblock_byte7; + __m256i vgroupid_byte0, vgroupid_byte1, vgroupid_byte2, vgroupid_byte3, vgroupid_byte4, + vgroupid_byte5, vgroupid_byte6, vgroupid_byte7; + // What we output if there is no match in the block + __m256i vslot_empty_or_end; + + constexpr uint32_t k4ByteSequence_0_4_8_12 = 0x0c080400; + constexpr uint32_t k4ByteSequence_1_5_9_13 = 0x0d090501; + constexpr uint32_t k4ByteSequence_2_6_10_14 = 0x0e0a0602; + constexpr uint32_t k4ByteSequence_3_7_11_15 = 0x0f0b0703; + constexpr uint64_t kEachByteIs1 = 0x0101010101010101ULL; + constexpr uint64_t kByteSequence7DownTo0 = 0x0001020304050607ULL; + constexpr uint64_t kByteSequence15DownTo8 = 0x08090A0B0C0D0E0FULL; + + // Bit unpack group ids into 1B. + // Assemble the sequence of block bytes. + uint64_t block_bytes[16]; + uint64_t groupid_bytes[16]; + const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + uint64_t bit_unpack_mask = ((1 << num_groupid_bits) - 1) * kEachByteIs1; + for (int i = 0; i < (1 << log_blocks_); ++i) { + uint64_t in_groupids = + *reinterpret_cast(blocks_ + (8 + num_groupid_bits) * i + 8); + uint64_t in_blockbytes = + *reinterpret_cast(blocks_ + (8 + num_groupid_bits) * i); + groupid_bytes[i] = _pdep_u64(in_groupids, bit_unpack_mask); + block_bytes[i] = in_blockbytes; + } + + // Split a sequence of 64-bit words into SIMD vectors holding individual bytes + __m256i vblock_words0 = + _mm256_loadu_si256(reinterpret_cast(block_bytes) + 0); + __m256i vblock_words1 = + _mm256_loadu_si256(reinterpret_cast(block_bytes) + 1); + __m256i vblock_words2 = + _mm256_loadu_si256(reinterpret_cast(block_bytes) + 2); + __m256i vblock_words3 = + _mm256_loadu_si256(reinterpret_cast(block_bytes) + 3); + // Reverse the bytes in blocks + __m256i vshuffle_const = + _mm256_setr_epi64x(kByteSequence7DownTo0, kByteSequence15DownTo8, + kByteSequence7DownTo0, kByteSequence15DownTo8); + vblock_words0 = _mm256_shuffle_epi8(vblock_words0, vshuffle_const); + vblock_words1 = _mm256_shuffle_epi8(vblock_words1, vshuffle_const); + vblock_words2 = _mm256_shuffle_epi8(vblock_words2, vshuffle_const); + vblock_words3 = _mm256_shuffle_epi8(vblock_words3, vshuffle_const); + split_bytes_avx2(vblock_words0, vblock_words1, vblock_words2, vblock_words3, + vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4, + vblock_byte5, vblock_byte6, vblock_byte7); + split_bytes_avx2( + _mm256_loadu_si256(reinterpret_cast(groupid_bytes) + 0), + _mm256_loadu_si256(reinterpret_cast(groupid_bytes) + 1), + _mm256_loadu_si256(reinterpret_cast(groupid_bytes) + 2), + _mm256_loadu_si256(reinterpret_cast(groupid_bytes) + 3), + vgroupid_byte0, vgroupid_byte1, vgroupid_byte2, vgroupid_byte3, vgroupid_byte4, + vgroupid_byte5, vgroupid_byte6, vgroupid_byte7); + + // Calculate the slot to output when there is no match in a block. + // It will be the index of the first empty slot or 8 (the number of slots in block) + // if there are no empty slots. + vslot_empty_or_end = _mm256_set1_epi8(8); + { + __m256i vis_empty; +#define CMP(VBLOCKBYTE, BYTENUM) \ + vis_empty = \ + _mm256_cmpeq_epi8(VBLOCKBYTE, _mm256_set1_epi8(static_cast(0x80))); \ + vslot_empty_or_end = \ + _mm256_blendv_epi8(vslot_empty_or_end, _mm256_set1_epi8(BYTENUM), vis_empty); + CMP(vblock_byte7, 7); + CMP(vblock_byte6, 6); + CMP(vblock_byte5, 5); + CMP(vblock_byte4, 4); + CMP(vblock_byte3, 3); + CMP(vblock_byte2, 2); + CMP(vblock_byte1, 1); + CMP(vblock_byte0, 0); +#undef CMP + } + + const int block_id_mask = (1 << log_blocks_) - 1; + + for (int i = 0; i < num_hashes / unroll; ++i) { + __m256i vhash0 = + _mm256_loadu_si256(reinterpret_cast(hashes) + 4 * i + 0); + __m256i vhash1 = + _mm256_loadu_si256(reinterpret_cast(hashes) + 4 * i + 1); + __m256i vhash2 = + _mm256_loadu_si256(reinterpret_cast(hashes) + 4 * i + 2); + __m256i vhash3 = + _mm256_loadu_si256(reinterpret_cast(hashes) + 4 * i + 3); + + // We will get input in byte lanes in the order: [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, + // 18, 26, ...] + vhash0 = _mm256_or_si256(_mm256_srli_epi32(vhash0, 16), + _mm256_and_si256(vhash2, _mm256_set1_epi32(0xffff0000))); + vhash1 = _mm256_or_si256(_mm256_srli_epi32(vhash1, 16), + _mm256_and_si256(vhash3, _mm256_set1_epi32(0xffff0000))); + __m256i vstamp_A = _mm256_and_si256( + _mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_ - 7)), + _mm256_set1_epi16(0x7f)); + __m256i vstamp_B = _mm256_and_si256( + _mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_ - 7)), + _mm256_set1_epi16(0x7f)); + __m256i vstamp = _mm256_or_si256(vstamp_A, _mm256_slli_epi16(vstamp_B, 8)); + __m256i vblock_id_A = + _mm256_and_si256(_mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_)), + _mm256_set1_epi16(block_id_mask)); + __m256i vblock_id_B = + _mm256_and_si256(_mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_)), + _mm256_set1_epi16(block_id_mask)); + __m256i vblock_id = _mm256_or_si256(vblock_id_A, _mm256_slli_epi16(vblock_id_B, 8)); + + // Visit all block bytes in reverse order (overwriting data on multiple matches) + __m256i vmatch_found = _mm256_setzero_si256(); + __m256i vslot_id = _mm256_shuffle_epi8(vslot_empty_or_end, vblock_id); + __m256i vgroup_id = _mm256_setzero_si256(); +#define CMP(VBLOCK_BYTE, VGROUPID_BYTE, BYTENUM) \ + { \ + __m256i vcmp = \ + _mm256_cmpeq_epi8(_mm256_shuffle_epi8(VBLOCK_BYTE, vblock_id), vstamp); \ + vmatch_found = _mm256_or_si256(vmatch_found, vcmp); \ + vgroup_id = _mm256_blendv_epi8(vgroup_id, \ + _mm256_shuffle_epi8(VGROUPID_BYTE, vblock_id), vcmp); \ + vslot_id = _mm256_blendv_epi8(vslot_id, _mm256_set1_epi8(BYTENUM + 1), vcmp); \ + } + CMP(vblock_byte7, vgroupid_byte7, 7); + CMP(vblock_byte6, vgroupid_byte6, 6); + CMP(vblock_byte5, vgroupid_byte5, 5); + CMP(vblock_byte4, vgroupid_byte4, 4); + CMP(vblock_byte3, vgroupid_byte3, 3); + CMP(vblock_byte2, vgroupid_byte2, 2); + CMP(vblock_byte1, vgroupid_byte1, 1); + CMP(vblock_byte0, vgroupid_byte0, 0); +#undef CMP + + vslot_id = _mm256_add_epi8(vslot_id, _mm256_slli_epi32(vblock_id, 3)); + // So far the output is in the order: [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, ...] + vmatch_found = _mm256_shuffle_epi8( + vmatch_found, + _mm256_setr_epi32(k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13, + k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15, + k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13, + k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15)); + // Now it is: [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, | 4, 5, 6, 7, + // 12, 13, 14, 15, ...] + vmatch_found = _mm256_permutevar8x32_epi32(vmatch_found, + _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + + reinterpret_cast(out_match_bitvector)[i] = + _mm256_movemask_epi8(vmatch_found); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 0, + _mm256_and_si256(vgroup_id, _mm256_set1_epi32(0xff))); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 1, + _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 8), _mm256_set1_epi32(0xff))); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 2, + _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 16), _mm256_set1_epi32(0xff))); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 3, + _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 24), _mm256_set1_epi32(0xff))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 0, + _mm256_and_si256(vslot_id, _mm256_set1_epi32(0xff))); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 1, + _mm256_and_si256(_mm256_srli_epi32(vslot_id, 8), _mm256_set1_epi32(0xff))); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 2, + _mm256_and_si256(_mm256_srli_epi32(vslot_id, 16), _mm256_set1_epi32(0xff))); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 3, + _mm256_and_si256(_mm256_srli_epi32(vslot_id, 24), _mm256_set1_epi32(0xff))); + } +} + +#endif + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc new file mode 100644 index 00000000000..5f1c0776c56 --- /dev/null +++ b/cpp/src/arrow/compute/exec/util.cc @@ -0,0 +1,234 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/util.h" + +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { + +using BitUtil::CountTrailingZeros; + +namespace util { + +inline void BitUtil::bits_to_indexes_helper(uint64_t word, uint16_t base_index, + int* num_indexes, uint16_t* indexes) { + int n = *num_indexes; + while (word) { + indexes[n++] = base_index + static_cast(CountTrailingZeros(word)); + word &= word - 1; + } + *num_indexes = n; +} + +inline void BitUtil::bits_filter_indexes_helper(uint64_t word, + const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes) { + int n = *num_indexes; + while (word) { + indexes[n++] = input_indexes[CountTrailingZeros(word)]; + word &= word - 1; + } + *num_indexes = n; +} + +template +void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes) { + // 64 bits at a time + constexpr int unroll = 64; + int tail = num_bits % unroll; +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + if (filter_input_indexes) { + bits_filter_indexes_avx2(bit_to_search, num_bits - tail, bits, input_indexes, + num_indexes, indexes); + } else { + bits_to_indexes_avx2(bit_to_search, num_bits - tail, bits, num_indexes, indexes); + } + } else { +#endif + *num_indexes = 0; + for (int i = 0; i < num_bits / unroll; ++i) { + uint64_t word = reinterpret_cast(bits)[i]; + if (bit_to_search == 0) { + word = ~word; + } + if (filter_input_indexes) { + bits_filter_indexes_helper(word, input_indexes + i * 64, num_indexes, indexes); + } else { + bits_to_indexes_helper(word, i * 64, num_indexes, indexes); + } + } +#if defined(ARROW_HAVE_AVX2) + } +#endif + // Optionally process the last partial word with masking out bits outside range + if (tail) { + uint64_t word = reinterpret_cast(bits)[num_bits / unroll]; + if (bit_to_search == 0) { + word = ~word; + } + word &= ~0ULL >> (64 - tail); + if (filter_input_indexes) { + bits_filter_indexes_helper(word, input_indexes + num_bits - tail, num_indexes, + indexes); + } else { + bits_to_indexes_helper(word, num_bits - tail, num_indexes, indexes); + } + } +} + +void BitUtil::bits_to_indexes(int bit_to_search, int64_t hardware_flags, + const int num_bits, const uint8_t* bits, int* num_indexes, + uint16_t* indexes) { + if (bit_to_search == 0) { + bits_to_indexes_internal<0, false>(hardware_flags, num_bits, bits, nullptr, + num_indexes, indexes); + } else { + ARROW_DCHECK(bit_to_search == 1); + bits_to_indexes_internal<1, false>(hardware_flags, num_bits, bits, nullptr, + num_indexes, indexes); + } +} + +void BitUtil::bits_filter_indexes(int bit_to_search, int64_t hardware_flags, + const int num_bits, const uint8_t* bits, + const uint16_t* input_indexes, int* num_indexes, + uint16_t* indexes) { + if (bit_to_search == 0) { + bits_to_indexes_internal<0, true>(hardware_flags, num_bits, bits, input_indexes, + num_indexes, indexes); + } else { + ARROW_DCHECK(bit_to_search == 1); + bits_to_indexes_internal<1, true>(hardware_flags, num_bits, bits, input_indexes, + num_indexes, indexes); + } +} + +void BitUtil::bits_split_indexes(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, int* num_indexes_bit0, + uint16_t* indexes_bit0, uint16_t* indexes_bit1) { + bits_to_indexes(0, hardware_flags, num_bits, bits, num_indexes_bit0, indexes_bit0); + int num_indexes_bit1; + bits_to_indexes(1, hardware_flags, num_bits, bits, &num_indexes_bit1, indexes_bit1); +} + +void BitUtil::bits_to_bytes_internal(const int num_bits, const uint8_t* bits, + uint8_t* bytes) { + constexpr int unroll = 8; + // Processing 8 bits at a time + for (int i = 0; i < (num_bits + unroll - 1) / unroll; ++i) { + uint8_t bits_next = bits[i]; + // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart + // from the previous. + uint64_t unpacked = static_cast(bits_next & 0xfe) * + ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) | + (1ULL << 35) | (1ULL << 42) | (1ULL << 49)); + unpacked |= (bits_next & 1); + unpacked &= 0x0101010101010101ULL; + unpacked *= 255; + reinterpret_cast(bytes)[i] = unpacked; + } +} + +void BitUtil::bytes_to_bits_internal(const int num_bits, const uint8_t* bytes, + uint8_t* bits) { + constexpr int unroll = 8; + // Process 8 bits at a time + for (int i = 0; i < (num_bits + unroll - 1) / unroll; ++i) { + uint64_t bytes_next = reinterpret_cast(bytes)[i]; + bytes_next &= 0x0101010101010101ULL; + bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes + bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes + bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte + bits[i] = static_cast(bytes_next & 0xff); + } +} + +void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, uint8_t* bytes) { + int num_processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + // The function call below processes whole 32 bit chunks together. + num_processed = num_bits - (num_bits % 32); + bits_to_bytes_avx2(num_processed, bits, bytes); + } +#endif + // Processing 8 bits at a time + constexpr int unroll = 8; + for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { + uint8_t bits_next = bits[i]; + // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart + // from the previous. + uint64_t unpacked = static_cast(bits_next & 0xfe) * + ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) | + (1ULL << 35) | (1ULL << 42) | (1ULL << 49)); + unpacked |= (bits_next & 1); + unpacked &= 0x0101010101010101ULL; + unpacked *= 255; + reinterpret_cast(bytes)[i] = unpacked; + } +} + +void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits, + const uint8_t* bytes, uint8_t* bits) { + int num_processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + // The function call below processes whole 32 bit chunks together. + num_processed = num_bits - (num_bits % 32); + bytes_to_bits_avx2(num_processed, bytes, bits); + } +#endif + // Process 8 bits at a time + constexpr int unroll = 8; + for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { + uint64_t bytes_next = reinterpret_cast(bytes)[i]; + bytes_next &= 0x0101010101010101ULL; + bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes + bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes + bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte + bits[i] = static_cast(bytes_next & 0xff); + } +} + +bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes, + uint32_t num_bytes) { +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + return are_all_bytes_zero_avx2(bytes, num_bytes); + } +#endif + uint64_t result_or = 0; + uint32_t i; + for (i = 0; i < num_bytes / 8; ++i) { + uint64_t x = reinterpret_cast(bytes)[i]; + result_or |= x; + } + if (num_bytes % 8 > 0) { + uint64_t tail = 0; + result_or |= memcmp(bytes + i * 8, &tail, num_bytes % 8); + } + return result_or == 0; +} + +} // namespace util +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h new file mode 100644 index 00000000000..d345bd3af0b --- /dev/null +++ b/cpp/src/arrow/compute/exec/util.h @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/buffer.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/cpu_info.h" +#include "arrow/util/logging.h" + +#if defined(__clang__) || defined(__GNUC__) +#define BYTESWAP(x) __builtin_bswap64(x) +#define ROTL(x, n) (((x) << (n)) | ((x) >> (32 - (n)))) +#elif defined(_MSC_VER) +#include +#define BYTESWAP(x) _byteswap_uint64(x) +#define ROTL(x, n) _rotl((x), (n)) +#endif + +namespace arrow { +namespace util { + +// Some platforms typedef int64_t as long int instead of long long int, +// which breaks the _mm256_i64gather_epi64 and _mm256_i32gather_epi64 intrinsics +// which need long long. +// We use the cast to the type below in these intrinsics to make the code +// compile in all cases. +// +using int64_for_gather_t = const long long int; // NOLINT runtime-int + +/// Storage used to allocate temporary vectors of a batch size. +/// Temporary vectors should resemble allocating temporary variables on the stack +/// but in the context of vectorized processing where we need to store a vector of +/// temporaries instead of a single value. +class TempVectorStack { + template + friend class TempVectorHolder; + + public: + Status Init(MemoryPool* pool, int64_t size) { + num_vectors_ = 0; + top_ = 0; + buffer_size_ = size; + ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(size, pool)); + buffer_ = std::move(buffer); + return Status::OK(); + } + + private: + void alloc(uint32_t num_bytes, uint8_t** data, int* id) { + int64_t old_top = top_; + top_ += num_bytes + padding; + // Stack overflow check + ARROW_DCHECK(top_ <= buffer_size_); + *data = buffer_->mutable_data() + old_top; + *id = num_vectors_++; + } + void release(int id, uint32_t num_bytes) { + ARROW_DCHECK(num_vectors_ == id + 1); + int64_t size = num_bytes + padding; + ARROW_DCHECK(top_ >= size); + top_ -= size; + --num_vectors_; + } + static constexpr int64_t padding = 64; + int num_vectors_; + int64_t top_; + std::unique_ptr buffer_; + int64_t buffer_size_; +}; + +template +class TempVectorHolder { + friend class TempVectorStack; + + public: + ~TempVectorHolder() { stack_->release(id_, num_elements_ * sizeof(T)); } + T* mutable_data() { return reinterpret_cast(data_); } + TempVectorHolder(TempVectorStack* stack, uint32_t num_elements) { + stack_ = stack; + num_elements_ = num_elements; + stack_->alloc(num_elements * sizeof(T), &data_, &id_); + } + + private: + TempVectorStack* stack_; + uint8_t* data_; + int id_; + uint32_t num_elements_; +}; + +class BitUtil { + public: + static void bits_to_indexes(int bit_to_search, int64_t hardware_flags, + const int num_bits, const uint8_t* bits, int* num_indexes, + uint16_t* indexes); + + static void bits_filter_indexes(int bit_to_search, int64_t hardware_flags, + const int num_bits, const uint8_t* bits, + const uint16_t* input_indexes, int* num_indexes, + uint16_t* indexes); + + // Input and output indexes may be pointing to the same data (in-place filtering). + static void bits_split_indexes(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, int* num_indexes_bit0, + uint16_t* indexes_bit0, uint16_t* indexes_bit1); + + // Bit 1 is replaced with byte 0xFF. + static void bits_to_bytes(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, uint8_t* bytes); + // Return highest bit of each byte. + static void bytes_to_bits(int64_t hardware_flags, const int num_bits, + const uint8_t* bytes, uint8_t* bits); + + static bool are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes, + uint32_t num_bytes); + + private: + inline static void bits_to_indexes_helper(uint64_t word, uint16_t base_index, + int* num_indexes, uint16_t* indexes); + inline static void bits_filter_indexes_helper(uint64_t word, + const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes); + template + static void bits_to_indexes_internal(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes); + static void bits_to_bytes_internal(const int num_bits, const uint8_t* bits, + uint8_t* bytes); + static void bytes_to_bits_internal(const int num_bits, const uint8_t* bytes, + uint8_t* bits); + +#if defined(ARROW_HAVE_AVX2) + static void bits_to_indexes_avx2(int bit_to_search, const int num_bits, + const uint8_t* bits, int* num_indexes, + uint16_t* indexes); + static void bits_filter_indexes_avx2(int bit_to_search, const int num_bits, + const uint8_t* bits, const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes); + template + static void bits_to_indexes_imp_avx2(const int num_bits, const uint8_t* bits, + int* num_indexes, uint16_t* indexes); + template + static void bits_filter_indexes_imp_avx2(const int num_bits, const uint8_t* bits, + const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes); + static void bits_to_bytes_avx2(const int num_bits, const uint8_t* bits, uint8_t* bytes); + static void bytes_to_bits_avx2(const int num_bits, const uint8_t* bytes, uint8_t* bits); + static bool are_all_bytes_zero_avx2(const uint8_t* bytes, uint32_t num_bytes); +#endif +}; + +} // namespace util +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util_avx2.cc b/cpp/src/arrow/compute/exec/util_avx2.cc new file mode 100644 index 00000000000..8cf0104db46 --- /dev/null +++ b/cpp/src/arrow/compute/exec/util_avx2.cc @@ -0,0 +1,217 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec/util.h" +#include "arrow/util/bit_util.h" + +namespace arrow { +namespace util { + +#if defined(ARROW_HAVE_AVX2) + +void BitUtil::bits_to_indexes_avx2(int bit_to_search, const int num_bits, + const uint8_t* bits, int* num_indexes, + uint16_t* indexes) { + if (bit_to_search == 0) { + bits_to_indexes_imp_avx2<0>(num_bits, bits, num_indexes, indexes); + } else { + ARROW_DCHECK(bit_to_search == 1); + bits_to_indexes_imp_avx2<1>(num_bits, bits, num_indexes, indexes); + } +} + +template +void BitUtil::bits_to_indexes_imp_avx2(const int num_bits, const uint8_t* bits, + int* num_indexes, uint16_t* indexes) { + // 64 bits at a time + constexpr int unroll = 64; + + // The caller takes care of processing the remaining bits at the end outside of the + // multiples of 64 + ARROW_DCHECK(num_bits % unroll == 0); + + constexpr uint64_t kEachByteIs1 = 0X0101010101010101ULL; + constexpr uint64_t kEachByteIs8 = 0x0808080808080808ULL; + constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; + + uint8_t byte_indexes[64]; + const uint64_t incr = kEachByteIs8; + const uint64_t mask = kByteSequence0To7; + *num_indexes = 0; + for (int i = 0; i < num_bits / unroll; ++i) { + uint64_t word = reinterpret_cast(bits)[i]; + if (bit_to_search == 0) { + word = ~word; + } + uint64_t base = 0; + int num_indexes_loop = 0; + while (word) { + uint64_t byte_indexes_next = + _pext_u64(mask, _pdep_u64(word, kEachByteIs1) * 0xff) + base; + *reinterpret_cast(byte_indexes + num_indexes_loop) = byte_indexes_next; + base += incr; + num_indexes_loop += static_cast(arrow::BitUtil::PopCount(word & 0xff)); + word >>= 8; + } + // Unpack indexes to 16-bits and either add the base of i * 64 or shuffle input + // indexes + for (int j = 0; j < (num_indexes_loop + 15) / 16; ++j) { + __m256i output = _mm256_cvtepi8_epi16( + _mm_loadu_si128(reinterpret_cast(byte_indexes) + j)); + output = _mm256_add_epi16(output, _mm256_set1_epi16(i * 64)); + _mm256_storeu_si256(((__m256i*)(indexes + *num_indexes)) + j, output); + } + *num_indexes += num_indexes_loop; + } +} + +void BitUtil::bits_filter_indexes_avx2(int bit_to_search, const int num_bits, + const uint8_t* bits, const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes) { + if (bit_to_search == 0) { + bits_filter_indexes_imp_avx2<0>(num_bits, bits, input_indexes, num_indexes, indexes); + } else { + bits_filter_indexes_imp_avx2<1>(num_bits, bits, input_indexes, num_indexes, indexes); + } +} + +template +void BitUtil::bits_filter_indexes_imp_avx2(const int num_bits, const uint8_t* bits, + const uint16_t* input_indexes, + int* out_num_indexes, uint16_t* indexes) { + // 64 bits at a time + constexpr int unroll = 64; + + // The caller takes care of processing the remaining bits at the end outside of the + // multiples of 64 + ARROW_DCHECK(num_bits % unroll == 0); + + constexpr uint64_t kRepeatedBitPattern0001 = 0x1111111111111111ULL; + constexpr uint64_t k4BitSequence0To15 = 0xfedcba9876543210ULL; + constexpr uint64_t kByteSequence_0_0_1_1_2_2_3_3 = 0x0303020201010000ULL; + constexpr uint64_t kByteSequence_4_4_5_5_6_6_7_7 = 0x0707060605050404ULL; + constexpr uint64_t kByteSequence_0_2_4_6_8_10_12_14 = 0x0e0c0a0806040200ULL; + constexpr uint64_t kByteSequence_1_3_5_7_9_11_13_15 = 0x0f0d0b0907050301ULL; + constexpr uint64_t kByteSequence_0_8_1_9_2_10_3_11 = 0x0b030a0209010800ULL; + constexpr uint64_t kByteSequence_4_12_5_13_6_14_7_15 = 0x0f070e060d050c04ULL; + + const uint64_t mask = k4BitSequence0To15; + int num_indexes = 0; + for (int i = 0; i < num_bits / unroll; ++i) { + uint64_t word = reinterpret_cast(bits)[i]; + if (bit_to_search == 0) { + word = ~word; + } + + int loop_id = 0; + while (word) { + uint64_t indexes_4bit = + _pext_u64(mask, _pdep_u64(word, kRepeatedBitPattern0001) * 0xf); + // Unpack 4 bit indexes to 8 bits + __m256i indexes_8bit = _mm256_set1_epi64x(indexes_4bit); + indexes_8bit = _mm256_shuffle_epi8( + indexes_8bit, + _mm256_setr_epi64x(kByteSequence_0_0_1_1_2_2_3_3, kByteSequence_4_4_5_5_6_6_7_7, + kByteSequence_0_0_1_1_2_2_3_3, + kByteSequence_4_4_5_5_6_6_7_7)); + indexes_8bit = _mm256_blendv_epi8( + _mm256_and_si256(indexes_8bit, _mm256_set1_epi8(0x0f)), + _mm256_and_si256(_mm256_srli_epi32(indexes_8bit, 4), _mm256_set1_epi8(0x0f)), + _mm256_set1_epi16(static_cast(0xff00))); + __m256i input = + _mm256_loadu_si256(((const __m256i*)input_indexes) + 4 * i + loop_id); + // Shuffle bytes to get low bytes in the first 128-bit lane and high bytes in the + // second + input = _mm256_shuffle_epi8( + input, _mm256_setr_epi64x( + kByteSequence_0_2_4_6_8_10_12_14, kByteSequence_1_3_5_7_9_11_13_15, + kByteSequence_0_2_4_6_8_10_12_14, kByteSequence_1_3_5_7_9_11_13_15)); + input = _mm256_permute4x64_epi64(input, 0xd8); // 0b11011000 + // Apply permutation + __m256i output = _mm256_shuffle_epi8(input, indexes_8bit); + // Move low and high bytes across 128-bit lanes to assemble back 16-bit indexes. + // (This is the reverse of the byte permutation we did on the input) + output = _mm256_permute4x64_epi64(output, + 0xd8); // The reverse of swapping 2nd and 3rd + // 64-bit element is the same permutation + output = _mm256_shuffle_epi8(output, + _mm256_setr_epi64x(kByteSequence_0_8_1_9_2_10_3_11, + kByteSequence_4_12_5_13_6_14_7_15, + kByteSequence_0_8_1_9_2_10_3_11, + kByteSequence_4_12_5_13_6_14_7_15)); + _mm256_storeu_si256((__m256i*)(indexes + num_indexes), output); + num_indexes += static_cast(arrow::BitUtil::PopCount(word & 0xffff)); + word >>= 16; + ++loop_id; + } + } + + *out_num_indexes = num_indexes; +} + +void BitUtil::bits_to_bytes_avx2(const int num_bits, const uint8_t* bits, + uint8_t* bytes) { + constexpr int unroll = 32; + + constexpr uint64_t kEachByteIs1 = 0x0101010101010101ULL; + constexpr uint64_t kEachByteIs2 = 0x0202020202020202ULL; + constexpr uint64_t kEachByteIs3 = 0x0303030303030303ULL; + constexpr uint64_t kByteSequencePowersOf2 = 0x8040201008040201ULL; + + // Processing 32 bits at a time + for (int i = 0; i < num_bits / unroll; ++i) { + __m256i unpacked = _mm256_set1_epi32(reinterpret_cast(bits)[i]); + unpacked = _mm256_shuffle_epi8( + unpacked, _mm256_setr_epi64x(0ULL, kEachByteIs1, kEachByteIs2, kEachByteIs3)); + __m256i bits_in_bytes = _mm256_set1_epi64x(kByteSequencePowersOf2); + unpacked = + _mm256_cmpeq_epi8(bits_in_bytes, _mm256_and_si256(unpacked, bits_in_bytes)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes) + i, unpacked); + } +} + +void BitUtil::bytes_to_bits_avx2(const int num_bits, const uint8_t* bytes, + uint8_t* bits) { + constexpr int unroll = 32; + // Processing 32 bits at a time + for (int i = 0; i < num_bits / unroll; ++i) { + reinterpret_cast(bits)[i] = _mm256_movemask_epi8( + _mm256_loadu_si256(reinterpret_cast(bytes) + i)); + } +} + +bool BitUtil::are_all_bytes_zero_avx2(const uint8_t* bytes, uint32_t num_bytes) { + __m256i result_or = _mm256_setzero_si256(); + uint32_t i; + for (i = 0; i < num_bytes / 32; ++i) { + __m256i x = _mm256_loadu_si256(reinterpret_cast(bytes) + i); + result_or = _mm256_or_si256(result_or, x); + } + uint32_t result_or32 = _mm256_movemask_epi8(result_or); + if (num_bytes % 32 > 0) { + uint64_t tail[4] = {0, 0, 0, 0}; + result_or32 |= memcmp(bytes + i * 32, tail, num_bytes % 32); + } + return result_or32 == 0; +} + +#endif // ARROW_HAVE_AVX2 + +} // namespace util +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index ae7bf9324db..0e5c8ace53f 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/api_aggregate.h" - #include #include #include @@ -24,7 +22,13 @@ #include #include "arrow/buffer_builder.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/api_vector.h" +#include "arrow/compute/exec/key_compare.h" +#include "arrow/compute/exec/key_encode.h" +#include "arrow/compute/exec/key_hash.h" +#include "arrow/compute/exec/key_map.h" +#include "arrow/compute/exec/util.h" #include "arrow/compute/exec_internal.h" #include "arrow/compute/kernel.h" #include "arrow/compute/kernels/aggregate_internal.h" @@ -33,6 +37,7 @@ #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_writer.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/cpu_info.h" #include "arrow/util/make_unique.h" #include "arrow/visitor_inline.h" @@ -436,6 +441,297 @@ struct GrouperImpl : Grouper { std::vector> encoders_; }; +struct GrouperFastImpl : Grouper { + static bool CanUse(const std::vector& keys) { +#if ARROW_LITTLE_ENDIAN + for (size_t i = 0; i < keys.size(); ++i) { + const auto& key = keys[i].type; + if (is_large_binary_like(key->id())) { + return false; + } + } + return true; +#else + return false; +#endif + } + + static Result> Make( + const std::vector& keys, ExecContext* ctx) { + auto impl = ::arrow::internal::make_unique(); + impl->ctx_ = ctx; + + RETURN_NOT_OK(impl->temp_stack_.Init(ctx->memory_pool(), 64 * minibatch_size_max_)); + impl->encode_ctx_.hardware_flags = + arrow::internal::CpuInfo::GetInstance()->hardware_flags(); + impl->encode_ctx_.stack = &impl->temp_stack_; + + auto num_columns = keys.size(); + impl->col_metadata_.resize(num_columns); + impl->key_types_.resize(num_columns); + impl->dictionaries_.resize(num_columns); + for (size_t icol = 0; icol < num_columns; ++icol) { + const auto& key = keys[icol].type; + if (key->id() == Type::DICTIONARY) { + auto bit_width = checked_cast(*key).bit_width(); + ARROW_DCHECK(bit_width % 8 == 0); + impl->col_metadata_[icol] = + arrow::compute::KeyEncoder::KeyColumnMetadata(true, bit_width / 8); + } else if (key->id() == Type::BOOL) { + impl->col_metadata_[icol] = + arrow::compute::KeyEncoder::KeyColumnMetadata(true, 0); + } else if (is_fixed_width(key->id())) { + impl->col_metadata_[icol] = arrow::compute::KeyEncoder::KeyColumnMetadata( + true, checked_cast(*key).bit_width() / 8); + } else if (is_binary_like(key->id())) { + impl->col_metadata_[icol] = + arrow::compute::KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t)); + } else { + return Status::NotImplemented("Keys of type ", *key); + } + impl->key_types_[icol] = key; + } + + impl->encoder_.Init(impl->col_metadata_, &impl->encode_ctx_, + /* row_alignment = */ sizeof(uint64_t), + /* string_alignment = */ sizeof(uint64_t)); + RETURN_NOT_OK(impl->rows_.Init(ctx->memory_pool(), impl->encoder_.row_metadata())); + RETURN_NOT_OK( + impl->rows_minibatch_.Init(ctx->memory_pool(), impl->encoder_.row_metadata())); + impl->minibatch_size_ = impl->minibatch_size_min_; + GrouperFastImpl* impl_ptr = impl.get(); + auto equal_func = [impl_ptr]( + int num_keys_to_compare, const uint16_t* selection_may_be_null, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch) { + arrow::compute::KeyCompare::CompareRows( + num_keys_to_compare, selection_may_be_null, group_ids, &impl_ptr->encode_ctx_, + out_num_keys_mismatch, out_selection_mismatch, impl_ptr->rows_minibatch_, + impl_ptr->rows_); + }; + auto append_func = [impl_ptr](int num_keys, const uint16_t* selection) { + return impl_ptr->rows_.AppendSelectionFrom(impl_ptr->rows_minibatch_, num_keys, + selection); + }; + RETURN_NOT_OK(impl->map_.init(impl->encode_ctx_.hardware_flags, ctx->memory_pool(), + impl->encode_ctx_.stack, impl->log_minibatch_max_, + equal_func, append_func)); + impl->cols_.resize(num_columns); + constexpr int padding_for_SIMD = 32; + impl->minibatch_hashes_.resize(impl->minibatch_size_max_ + + padding_for_SIMD / sizeof(uint32_t)); + + return std::move(impl); + } + + ~GrouperFastImpl() { map_.cleanup(); } + + Result Consume(const ExecBatch& batch) override { + int64_t num_rows = batch.length; + int num_columns = batch.num_values(); + + // Process dictionaries + for (int icol = 0; icol < num_columns; ++icol) { + if (key_types_[icol]->id() == Type::DICTIONARY) { + auto data = batch[icol].array(); + auto dict = MakeArray(data->dictionary); + if (dictionaries_[icol]) { + if (!dictionaries_[icol]->Equals(dict)) { + // TODO(bkietz) unify if necessary. For now, just error if any batch's + // dictionary differs from the first we saw for this key + return Status::NotImplemented("Unifying differing dictionaries"); + } + } else { + dictionaries_[icol] = std::move(dict); + } + } + } + + std::shared_ptr group_ids; + ARROW_ASSIGN_OR_RAISE( + group_ids, AllocateBuffer(sizeof(uint32_t) * num_rows, ctx_->memory_pool())); + + for (int icol = 0; icol < num_columns; ++icol) { + const uint8_t* non_nulls = nullptr; + if (batch[icol].array()->buffers[0] != NULLPTR) { + non_nulls = batch[icol].array()->buffers[0]->data(); + } + const uint8_t* fixedlen = batch[icol].array()->buffers[1]->data(); + const uint8_t* varlen = nullptr; + if (!col_metadata_[icol].is_fixed_length) { + varlen = batch[icol].array()->buffers[2]->data(); + } + + cols_[icol] = arrow::compute::KeyEncoder::KeyColumnArray( + col_metadata_[icol], num_rows, non_nulls, fixedlen, varlen); + } + + // Split into smaller mini-batches + // + for (uint32_t start_row = 0; start_row < num_rows;) { + uint32_t batch_size_next = std::min(static_cast(minibatch_size_), + static_cast(num_rows) - start_row); + + // Encode + rows_minibatch_.Clean(); + RETURN_NOT_OK(encoder_.PrepareOutputForEncode(start_row, batch_size_next, + &rows_minibatch_, cols_)); + encoder_.Encode(start_row, batch_size_next, &rows_minibatch_, cols_); + + // Compute hash + if (encoder_.row_metadata().is_fixed_length) { + Hashing::hash_fixed(encode_ctx_.hardware_flags, batch_size_next, + encoder_.row_metadata().fixed_length, rows_minibatch_.data(1), + minibatch_hashes_.data()); + } else { + auto hash_temp_buf = + util::TempVectorHolder(&temp_stack_, 4 * batch_size_next); + Hashing::hash_varlen(encode_ctx_.hardware_flags, batch_size_next, + rows_minibatch_.offsets(), rows_minibatch_.data(2), + hash_temp_buf.mutable_data(), minibatch_hashes_.data()); + } + + // Map + RETURN_NOT_OK( + map_.map(batch_size_next, minibatch_hashes_.data(), + reinterpret_cast(group_ids->mutable_data()) + start_row)); + + start_row += batch_size_next; + + if (minibatch_size_ * 2 <= minibatch_size_max_) { + minibatch_size_ *= 2; + } + } + + return Datum(UInt32Array(batch.length, std::move(group_ids))); + } + + uint32_t num_groups() const override { return static_cast(rows_.length()); } + + Result GetUniques() override { + auto num_columns = static_cast(col_metadata_.size()); + int64_t num_groups = rows_.length(); + + std::vector> non_null_bufs(num_columns); + std::vector> fixedlen_bufs(num_columns); + std::vector> varlen_bufs(num_columns); + + constexpr int padding_bits = 64; + constexpr int padding_for_SIMD = 32; + for (size_t i = 0; i < num_columns; ++i) { + ARROW_ASSIGN_OR_RAISE(non_null_bufs[i], AllocateBitmap(num_groups + padding_bits, + ctx_->memory_pool())); + if (col_metadata_[i].is_fixed_length) { + if (col_metadata_[i].fixed_length == 0) { + ARROW_ASSIGN_OR_RAISE( + fixedlen_bufs[i], + AllocateBitmap(num_groups + padding_bits, ctx_->memory_pool())); + } else { + ARROW_ASSIGN_OR_RAISE( + fixedlen_bufs[i], + AllocateBuffer( + num_groups * col_metadata_[i].fixed_length + padding_for_SIMD, + ctx_->memory_pool())); + } + } else { + ARROW_ASSIGN_OR_RAISE( + fixedlen_bufs[i], + AllocateBuffer((num_groups + 1) * sizeof(uint32_t) + padding_for_SIMD, + ctx_->memory_pool())); + } + cols_[i] = arrow::compute::KeyEncoder::KeyColumnArray( + col_metadata_[i], num_groups, non_null_bufs[i]->mutable_data(), + fixedlen_bufs[i]->mutable_data(), nullptr); + } + + for (int64_t start_row = 0; start_row < num_groups;) { + int64_t batch_size_next = + std::min(num_groups - start_row, static_cast(minibatch_size_max_)); + encoder_.DecodeFixedLengthBuffers(start_row, start_row, batch_size_next, rows_, + &cols_); + start_row += batch_size_next; + } + + if (!rows_.metadata().is_fixed_length) { + for (size_t i = 0; i < num_columns; ++i) { + if (!col_metadata_[i].is_fixed_length) { + auto varlen_size = + reinterpret_cast(fixedlen_bufs[i]->data())[num_groups]; + ARROW_ASSIGN_OR_RAISE( + varlen_bufs[i], + AllocateBuffer(varlen_size + padding_for_SIMD, ctx_->memory_pool())); + cols_[i] = arrow::compute::KeyEncoder::KeyColumnArray( + col_metadata_[i], num_groups, non_null_bufs[i]->mutable_data(), + fixedlen_bufs[i]->mutable_data(), varlen_bufs[i]->mutable_data()); + } + } + + for (int64_t start_row = 0; start_row < num_groups;) { + int64_t batch_size_next = + std::min(num_groups - start_row, static_cast(minibatch_size_max_)); + encoder_.DecodeVaryingLengthBuffers(start_row, start_row, batch_size_next, rows_, + &cols_); + start_row += batch_size_next; + } + } + + ExecBatch out({}, num_groups); + out.values.resize(num_columns); + for (size_t i = 0; i < num_columns; ++i) { + auto valid_count = arrow::internal::CountSetBits( + non_null_bufs[i]->data(), /*offset=*/0, static_cast(num_groups)); + int null_count = static_cast(num_groups) - static_cast(valid_count); + + if (col_metadata_[i].is_fixed_length) { + out.values[i] = ArrayData::Make( + key_types_[i], num_groups, + {std::move(non_null_bufs[i]), std::move(fixedlen_bufs[i])}, null_count); + } else { + out.values[i] = + ArrayData::Make(key_types_[i], num_groups, + {std::move(non_null_bufs[i]), std::move(fixedlen_bufs[i]), + std::move(varlen_bufs[i])}, + null_count); + } + } + + // Process dictionaries + for (size_t icol = 0; icol < num_columns; ++icol) { + if (key_types_[icol]->id() == Type::DICTIONARY) { + if (dictionaries_[icol]) { + out.values[icol].array()->dictionary = dictionaries_[icol]->data(); + } else { + ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(key_types_[icol], 0)); + out.values[icol].array()->dictionary = dict->data(); + } + } + } + + return out; + } + + static constexpr int log_minibatch_max_ = 10; + static constexpr int minibatch_size_max_ = 1 << log_minibatch_max_; + static constexpr int minibatch_size_min_ = 128; + int minibatch_size_; + + ExecContext* ctx_; + arrow::util::TempVectorStack temp_stack_; + arrow::compute::KeyEncoder::KeyEncoderContext encode_ctx_; + + std::vector> key_types_; + std::vector col_metadata_; + std::vector cols_; + std::vector minibatch_hashes_; + + std::vector> dictionaries_; + + arrow::compute::KeyEncoder::KeyRowArray rows_; + arrow::compute::KeyEncoder::KeyRowArray rows_minibatch_; + arrow::compute::KeyEncoder encoder_; + arrow::compute::SwissTable map_; +}; + /// C++ abstract base class for the HashAggregateKernel interface. /// Implementations should be default constructible and perform initialization in /// Init(). @@ -884,6 +1180,9 @@ Result ResolveKernels( Result> Grouper::Make(const std::vector& descrs, ExecContext* ctx) { + if (GrouperFastImpl::CanUse(descrs)) { + return GrouperFastImpl::Make(descrs, ctx); + } return GrouperImpl::Make(descrs, ctx); } diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index 507f1716110..a0d2fd208a9 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include + #include #include #include @@ -22,8 +24,6 @@ #include #include -#include - #include "arrow/array.h" #include "arrow/chunked_array.h" #include "arrow/compute/api_aggregate.h" @@ -182,10 +182,52 @@ struct TestGrouper { ExpectConsume(*ExecBatch::Make(key_batch), expected); } + void AssertEquivalentIds(const Datum& expected, const Datum& actual) { + auto left = expected.make_array(); + auto right = actual.make_array(); + ASSERT_EQ(left->length(), right->length()) << "#ids unequal"; + int64_t num_ids = left->length(); + auto left_data = left->data(); + auto right_data = right->data(); + const uint32_t* left_ids = + reinterpret_cast(left_data->buffers[1]->data()); + const uint32_t* right_ids = + reinterpret_cast(right_data->buffers[1]->data()); + uint32_t max_left_id = 0; + uint32_t max_right_id = 0; + for (int64_t i = 0; i < num_ids; ++i) { + if (left_ids[i] > max_left_id) { + max_left_id = left_ids[i]; + } + if (right_ids[i] > max_right_id) { + max_right_id = right_ids[i]; + } + } + std::vector right_to_left_present(max_right_id + 1, false); + std::vector left_to_right_present(max_left_id + 1, false); + std::vector right_to_left(max_right_id + 1); + std::vector left_to_right(max_left_id + 1); + for (int64_t i = 0; i < num_ids; ++i) { + uint32_t left_id = left_ids[i]; + uint32_t right_id = right_ids[i]; + if (!left_to_right_present[left_id]) { + left_to_right[left_id] = right_id; + left_to_right_present[left_id] = true; + } + if (!right_to_left_present[right_id]) { + right_to_left[right_id] = left_id; + right_to_left_present[right_id] = true; + } + ASSERT_EQ(left_id, right_to_left[right_id]); + ASSERT_EQ(right_id, left_to_right[left_id]); + } + } + void ExpectConsume(const ExecBatch& key_batch, Datum expected) { Datum ids; ConsumeAndValidate(key_batch, &ids); - AssertDatumsEqual(expected, ids, /*verbose=*/true); + AssertEquivalentIds(expected, ids); + // AssertDatumsEqual(expected, ids, /*verbose=*/true); } void ConsumeAndValidate(const ExecBatch& key_batch, Datum* ids = nullptr) { diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 1c776f18329..7a7ffcff229 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -85,16 +85,23 @@ class TestPartitioning : public ::testing::Test { const std::vector& expected_expressions) { ASSERT_OK_AND_ASSIGN(auto partition_results, partitioning->Partition(full_batch)); std::shared_ptr rest = full_batch; + ASSERT_EQ(partition_results.batches.size(), expected_batches.size()); - auto max_index = std::min(partition_results.batches.size(), expected_batches.size()); - for (std::size_t partition_index = 0; partition_index < max_index; - partition_index++) { - std::shared_ptr actual_batch = - partition_results.batches[partition_index]; - AssertBatchesEqual(*expected_batches[partition_index], *actual_batch); - compute::Expression actual_expression = - partition_results.expressions[partition_index]; - ASSERT_EQ(expected_expressions[partition_index], actual_expression); + + for (size_t i = 0; i < partition_results.batches.size(); i++) { + std::shared_ptr actual_batch = partition_results.batches[i]; + compute::Expression actual_expression = partition_results.expressions[i]; + + auto expected_expression = std::find(expected_expressions.begin(), + expected_expressions.end(), actual_expression); + ASSERT_NE(expected_expression, expected_expressions.end()) + << "Unexpected partition expr " << actual_expression.ToString(); + + auto expected_batch = + expected_batches[expected_expression - expected_expressions.begin()]; + + SCOPED_TRACE("Batch for " + expected_expression->ToString()); + AssertBatchesEqual(*expected_batch, *actual_batch); } }