diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index dde66ac79c4..e024483f665 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -167,6 +167,11 @@ struct ARROW_EXPORT ArrayData { std::shared_ptr Copy() const { return std::make_shared(*this); } + bool IsNull(int64_t i) const { + return ((buffers[0] != NULLPTR) ? !bit_util::GetBit(buffers[0]->data(), i + offset) + : null_count.load() == length); + } + // Access a buffer's data as a typed C pointer template inline const T* GetValues(int i, int64_t absolute_offset) const { @@ -324,18 +329,14 @@ struct ARROW_EXPORT ArraySpan { return GetValues(i, this->offset); } - bool IsNull(int64_t i) const { - return ((this->buffers[0].data != NULLPTR) - ? !bit_util::GetBit(this->buffers[0].data, i + this->offset) - : this->null_count == this->length); - } - - bool IsValid(int64_t i) const { + inline bool IsValid(int64_t i) const { return ((this->buffers[0].data != NULLPTR) ? bit_util::GetBit(this->buffers[0].data, i + this->offset) : this->null_count != this->length); } + inline bool IsNull(int64_t i) const { return !IsValid(i); } + std::shared_ptr ToArrayData() const; std::shared_ptr ToArray() const; diff --git a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc index 543a4ece575..7d8abc0ba4c 100644 --- a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc @@ -109,7 +109,7 @@ static void TableJoinOverhead(benchmark::State& state, static void AsOfJoinOverhead(benchmark::State& state) { int64_t tolerance = 0; - AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, kKeyCol, tolerance); + AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, {kKeyCol}, tolerance); TableJoinOverhead( state, TableGenerationProperties{int(state.range(0)), int(state.range(1)), diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 3da612aa03e..869456a5775 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -17,34 +17,63 @@ #include #include -#include #include #include +#include +#include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/util.h" +#include "arrow/compute/light_array.h" #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" #include "arrow/util/make_unique.h" #include "arrow/util/optional.h" +#include "arrow/util/string_view.h" namespace arrow { namespace compute { -// Remove this when multiple keys and/or types is supported -typedef int32_t KeyType; +template +inline typename T::const_iterator std_find(const T& container, const V& val) { + return std::find(container.begin(), container.end(), val); +} + +template +inline bool std_has(const T& container, const V& val) { + return container.end() != std_find(container, val); +} + +typedef uint64_t ByType; +typedef uint64_t OnType; +typedef uint64_t HashType; // Maximum number of tables that can be joined #define MAX_JOIN_TABLES 64 typedef uint64_t row_index_t; typedef int col_index_t; +// normalize the value to 64-bits while preserving ordering of values +template ::value, bool> = true> +static 
inline uint64_t time_value(T t) { + uint64_t bias = std::is_signed::value ? (uint64_t)1 << (8 * sizeof(T) - 1) : 0; + return t < 0 ? static_cast(t + bias) : static_cast(t); +} + +// indicates normalization of a key value +template ::value, bool> = true> +static inline uint64_t key_value(T t) { + return static_cast(t); +} + /** * Simple implementation for an unbound concurrent queue */ @@ -65,6 +94,11 @@ class ConcurrentQueue { cond_.notify_one(); } + void Clear() { + std::unique_lock lock(mutex_); + queue_ = std::queue(); + } + util::optional TryPop() { // Try to pop the oldest value from the queue (or return nullopt if none) std::unique_lock lock(mutex_); @@ -99,7 +133,7 @@ struct MemoStore { struct Entry { // Timestamp associated with the entry - int64_t time; + OnType time; // Batch associated with the entry (perf is probably OK for this; batches change // rarely) @@ -109,10 +143,10 @@ struct MemoStore { row_index_t row; }; - std::unordered_map entries_; + std::unordered_map entries_; - void Store(const std::shared_ptr& batch, row_index_t row, int64_t time, - KeyType key) { + void Store(const std::shared_ptr& batch, row_index_t row, OnType time, + ByType key) { auto& e = entries_[key]; // that we can do this assignment optionally, is why we // can get array with using shared_ptr above (the batch @@ -122,13 +156,13 @@ struct MemoStore { e.time = time; } - util::optional GetEntryForKey(KeyType key) const { + util::optional GetEntryForKey(ByType key) const { auto e = entries_.find(key); if (entries_.end() == e) return util::nullopt; return util::optional(&e->second); } - void RemoveEntriesWithLesserTime(int64_t ts) { + void RemoveEntriesWithLesserTime(OnType ts) { for (auto e = entries_.begin(); e != entries_.end();) if (e->second.time < ts) e = entries_.erase(e); @@ -137,18 +171,89 @@ struct MemoStore { } }; +// a specialized higher-performance variation of Hashing64 logic from hash_join_node +// the code here avoids recreating objects that are independent of each batch processed +class KeyHasher { + static constexpr int kMiniBatchLength = util::MiniBatch::kMiniBatchLength; + + public: + explicit KeyHasher(const std::vector& indices) + : indices_(indices), + metadata_(indices.size()), + batch_(NULLPTR), + hashes_(), + ctx_(), + column_arrays_(), + stack_() { + ctx_.stack = &stack_; + column_arrays_.resize(indices.size()); + } + + Status Init(ExecContext* exec_context, const std::shared_ptr& schema) { + ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags(); + const auto& fields = schema->fields(); + for (size_t k = 0; k < metadata_.size(); k++) { + ARROW_ASSIGN_OR_RAISE(metadata_[k], + ColumnMetadataFromDataType(fields[indices_[k]]->type())); + } + return stack_.Init(exec_context->memory_pool(), + 4 * kMiniBatchLength * sizeof(uint32_t)); + } + + const std::vector& HashesFor(const RecordBatch* batch) { + if (batch_ == batch) { + return hashes_; + } + batch_ = NULLPTR; // invalidate cached hashes for batch + size_t batch_length = batch->num_rows(); + hashes_.resize(batch_length); + for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { + int64_t length = std::min(static_cast(batch_length - i), + static_cast(kMiniBatchLength)); + for (size_t k = 0; k < indices_.size(); k++) { + auto array_data = batch->column_data(indices_[k]); + column_arrays_[k] = + ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length); + } + Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i); + } + batch_ = batch; + return hashes_; + } + + private: + 
std::vector indices_; + std::vector metadata_; + const RecordBatch* batch_; + std::vector hashes_; + LightContext ctx_; + std::vector column_arrays_; + util::TempVectorStack stack_; +}; + class InputState { // InputState correponds to an input // Input record batches are queued up in InputState until processed and // turned into output record batches. public: - InputState(const std::shared_ptr& schema, - const std::string& time_col_name, const std::string& key_col_name) + InputState(bool must_hash, bool may_rehash, KeyHasher* key_hasher, + const std::shared_ptr& schema, + const col_index_t time_col_index, + const std::vector& key_col_index) : queue_(), schema_(schema), - time_col_index_(schema->GetFieldIndex(time_col_name)), - key_col_index_(schema->GetFieldIndex(key_col_name)) {} + time_col_index_(time_col_index), + key_col_index_(key_col_index), + time_type_id_(schema_->fields()[time_col_index_]->type()->id()), + key_type_id_(key_col_index.size()), + key_hasher_(key_hasher), + must_hash_(must_hash), + may_rehash_(may_rehash) { + for (size_t k = 0; k < key_col_index_.size(); k++) { + key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); + } + } col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); @@ -164,7 +269,7 @@ class InputState { bool IsTimeOrKeyColumn(col_index_t i) const { DCHECK_LT(i, schema_->num_fields()); - return (i == time_col_index_) || (i == key_col_index_); + return (i == time_col_index_) || std_has(key_col_index_, i); } // Gets the latest row index, assuming the queue isn't empty @@ -184,27 +289,87 @@ class InputState { return queue_.UnsyncFront(); } - KeyType GetLatestKey() const { - return queue_.UnsyncFront() - ->column_data(key_col_index_) - ->GetValues(1)[latest_ref_row_]; +#define LATEST_VAL_CASE(id, val) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + using CType = typename TypeTraits::CType; \ + return val(data->GetValues(1)[row]); \ + } + + inline ByType GetLatestKey() const { + return GetLatestKey(queue_.UnsyncFront().get(), latest_ref_row_); } - int64_t GetLatestTime() const { - return queue_.UnsyncFront() - ->column_data(time_col_index_) - ->GetValues(1)[latest_ref_row_]; + inline ByType GetLatestKey(const RecordBatch* batch, row_index_t row) const { + if (must_hash_) { + return key_hasher_->HashesFor(batch)[row]; + } + if (key_col_index_.size() == 0) { + return 0; + } + auto data = batch->column_data(key_col_index_[0]); + switch (key_type_id_[0]) { + LATEST_VAL_CASE(INT8, key_value) + LATEST_VAL_CASE(INT16, key_value) + LATEST_VAL_CASE(INT32, key_value) + LATEST_VAL_CASE(INT64, key_value) + LATEST_VAL_CASE(UINT8, key_value) + LATEST_VAL_CASE(UINT16, key_value) + LATEST_VAL_CASE(UINT32, key_value) + LATEST_VAL_CASE(UINT64, key_value) + LATEST_VAL_CASE(DATE32, key_value) + LATEST_VAL_CASE(DATE64, key_value) + LATEST_VAL_CASE(TIME32, key_value) + LATEST_VAL_CASE(TIME64, key_value) + LATEST_VAL_CASE(TIMESTAMP, key_value) + default: + DCHECK(false); + return 0; // cannot happen + } } + inline OnType GetLatestTime() const { + return GetLatestTime(queue_.UnsyncFront().get(), latest_ref_row_); + } + + inline ByType GetLatestTime(const RecordBatch* batch, row_index_t row) const { + auto data = batch->column_data(time_col_index_); + switch (time_type_id_) { + LATEST_VAL_CASE(INT8, time_value) + LATEST_VAL_CASE(INT16, time_value) + LATEST_VAL_CASE(INT32, time_value) + LATEST_VAL_CASE(INT64, time_value) + LATEST_VAL_CASE(UINT8, time_value) + 
LATEST_VAL_CASE(UINT16, time_value) + LATEST_VAL_CASE(UINT32, time_value) + LATEST_VAL_CASE(UINT64, time_value) + LATEST_VAL_CASE(DATE32, time_value) + LATEST_VAL_CASE(DATE64, time_value) + LATEST_VAL_CASE(TIME32, time_value) + LATEST_VAL_CASE(TIME64, time_value) + LATEST_VAL_CASE(TIMESTAMP, time_value) + default: + DCHECK(false); + return 0; // cannot happen + } + } + +#undef LATEST_VAL_CASE + bool Finished() const { return batches_processed_ == total_batches_; } - bool Advance() { + Result Advance() { // Try advancing to the next row and update latest_ref_row_ // Returns true if able to advance, false if not. bool have_active_batch = (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); if (have_active_batch) { + OnType next_time = GetLatestTime(); + if (latest_time_ > next_time) { + return Status::Invalid("AsofJoin does not allow out-of-order on-key values"); + } + latest_time_ = next_time; // If we have an active batch if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. @@ -222,46 +387,60 @@ class InputState { // latest_time and latest_ref_row to the value that immediately pass the // specified timestamp. // Returns true if updates were made, false if not. - bool AdvanceAndMemoize(int64_t ts) { + Result AdvanceAndMemoize(OnType ts) { // Advance the right side row index until we reach the latest right row (for each key) // for the given left timestamp. // Check if already updated for TS (or if there is no latest) if (Empty()) return false; // can't advance if empty - auto latest_time = GetLatestTime(); - if (latest_time > ts) return false; // already advanced // Not updated. Try to update and possibly advance. - bool updated = false; + bool advanced, updated = false; do { - latest_time = GetLatestTime(); + auto latest_time = GetLatestTime(); // if Advance() returns true, then the latest_ts must also be valid // Keep advancing right table until we hit the latest row that has // timestamp <= ts. This is because we only need the latest row for the // match given a left ts. 
- if (latest_time <= ts) { - memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); - } else { + if (latest_time > ts) { break; // hit a future timestamp -- done updating for now } + auto rb = GetLatestBatch(); + if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { + must_hash_ = true; + may_rehash_ = false; + Rehash(); + } + memo_.Store(rb, latest_ref_row_, latest_time, GetLatestKey()); updated = true; - } while (Advance()); + ARROW_ASSIGN_OR_RAISE(advanced, Advance()); + } while (advanced); return updated; } - void Push(const std::shared_ptr& rb) { + void Rehash() { + MemoStore new_memo; + for (const auto& entry : memo_.entries_) { + const auto& e = entry.second; + new_memo.Store(e.batch, e.row, e.time, GetLatestKey(e.batch.get(), e.row)); + } + memo_ = new_memo; + } + + Status Push(const std::shared_ptr& rb) { if (rb->num_rows() > 0) { queue_.Push(rb); } else { ++batches_processed_; // don't enqueue empty batches, just record as processed } + return Status::OK(); } - util::optional GetMemoEntryForKey(KeyType key) { + util::optional GetMemoEntryForKey(ByType key) { return memo_.GetEntryForKey(key); } - util::optional GetMemoTimeForKey(KeyType key) { + util::optional GetMemoTimeForKey(ByType key) { auto r = GetMemoEntryForKey(key); if (r.has_value()) { return (*r)->time; @@ -270,7 +449,7 @@ class InputState { } } - void RemoveMemoEntriesWithLesserTime(int64_t ts) { + void RemoveMemoEntriesWithLesserTime(OnType ts) { memo_.RemoveEntriesWithLesserTime(ts); } @@ -294,10 +473,22 @@ class InputState { // Index of the time col col_index_t time_col_index_; // Index of the key col - col_index_t key_col_index_; + std::vector key_col_index_; + // Type id of the time column + Type::type time_type_id_; + // Type id of the key column + std::vector key_type_id_; + // Hasher for key elements + mutable KeyHasher* key_hasher_; + // True if hashing is mandatory + bool must_hash_; + // True if by-key values may be rehashed + bool may_rehash_; // Index of the latest row reference within; if >0 then queue_ cannot be empty // Must be < queue_.front()->num_rows() if queue_ is non-empty row_index_t latest_ref_row_ = 0; + // Time of latest row + OnType latest_time_ = std::numeric_limits::lowest(); // Stores latest known values for the various keys MemoStore memo_; // Mapping of source columns to destination columns @@ -336,18 +527,18 @@ class CompositeReferenceTable { // Adds the latest row from the input state as a new composite reference row // - LHS must have a valid key,timestep,and latest rows // - RHS must have valid data memo'ed for the key - void Emplace(std::vector>& in, int64_t tolerance) { + void Emplace(std::vector>& in, OnType tolerance) { DCHECK_EQ(in.size(), n_tables_); // Get the LHS key - KeyType key = in[0]->GetLatestKey(); + ByType key = in[0]->GetLatestKey(); // Add row and setup LHS // (the LHS state comes just from the latest row of the LHS table) DCHECK(!in[0]->Empty()); const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); row_index_t lhs_latest_row = in[0]->GetLatestRow(); - int64_t lhs_latest_time = in[0]->GetLatestTime(); + OnType lhs_latest_time = in[0]->GetLatestTime(); if (0 == lhs_latest_row) { // On the first row of the batch, we resize the destination. // The destination size is dictated by the size of the LHS batch. 
@@ -407,29 +598,42 @@ class CompositeReferenceTable { DCHECK_EQ(src_field->name(), dst_field->name()); const auto& field_type = src_field->type(); - if (field_type->Equals(arrow::int32())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::int64())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::float32())) { - ARROW_ASSIGN_OR_RAISE(arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::float64())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else { - ARROW_RETURN_NOT_OK( - Status::Invalid("Unsupported data type: ", src_field->name())); +#define ASOFJOIN_MATERIALIZE_CASE(id) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + ARROW_ASSIGN_OR_RAISE( \ + arrays.at(i_dst_col), \ + MaterializeColumn(memory_pool, field_type, i_table, i_src_col)); \ + break; \ + } + + switch (field_type->id()) { + ASOFJOIN_MATERIALIZE_CASE(INT8) + ASOFJOIN_MATERIALIZE_CASE(INT16) + ASOFJOIN_MATERIALIZE_CASE(INT32) + ASOFJOIN_MATERIALIZE_CASE(INT64) + ASOFJOIN_MATERIALIZE_CASE(UINT8) + ASOFJOIN_MATERIALIZE_CASE(UINT16) + ASOFJOIN_MATERIALIZE_CASE(UINT32) + ASOFJOIN_MATERIALIZE_CASE(UINT64) + ASOFJOIN_MATERIALIZE_CASE(FLOAT) + ASOFJOIN_MATERIALIZE_CASE(DOUBLE) + ASOFJOIN_MATERIALIZE_CASE(DATE32) + ASOFJOIN_MATERIALIZE_CASE(DATE64) + ASOFJOIN_MATERIALIZE_CASE(TIME32) + ASOFJOIN_MATERIALIZE_CASE(TIME64) + ASOFJOIN_MATERIALIZE_CASE(TIMESTAMP) + ASOFJOIN_MATERIALIZE_CASE(STRING) + ASOFJOIN_MATERIALIZE_CASE(LARGE_STRING) + ASOFJOIN_MATERIALIZE_CASE(BINARY) + ASOFJOIN_MATERIALIZE_CASE(LARGE_BINARY) + default: + return Status::Invalid("Unsupported data type ", + src_field->type()->ToString(), " for field ", + src_field->name()); } + +#undef ASOFJOIN_MATERIALIZE_CASE } } } @@ -459,17 +663,45 @@ class CompositeReferenceTable { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } - template - Result> MaterializePrimitiveColumn(MemoryPool* memory_pool, - size_t i_table, - col_index_t i_col) { - Builder builder(memory_pool); + template ::BuilderType> + enable_if_fixed_width_type static BuilderAppend( + Builder& builder, const std::shared_ptr& source, row_index_t row) { + if (source->IsNull(row)) { + builder.UnsafeAppendNull(); + return Status::OK(); + } + using CType = typename TypeTraits::CType; + builder.UnsafeAppend(source->template GetValues(1)[row]); + return Status::OK(); + } + + template ::BuilderType> + enable_if_base_binary static BuilderAppend( + Builder& builder, const std::shared_ptr& source, row_index_t row) { + if (source->IsNull(row)) { + return builder.AppendNull(); + } + using offset_type = typename Type::offset_type; + const uint8_t* data = source->buffers[2]->data(); + const offset_type* offsets = source->GetValues(1); + const offset_type offset0 = offsets[row]; + const offset_type offset1 = offsets[row + 1]; + return builder.Append(data + offset0, offset1 - offset0); + } + + template ::BuilderType> + Result> MaterializeColumn(MemoryPool* memory_pool, + const std::shared_ptr& type, + size_t i_table, col_index_t i_col) { + ARROW_ASSIGN_OR_RAISE(auto a_builder, MakeBuilder(type, memory_pool)); + Builder& builder = *checked_cast(a_builder.get()); ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size())); 
for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; if (ref.batch) { - builder.UnsafeAppend( - ref.batch->column_data(i_col)->template GetValues(1)[ref.row]); + Status st = + BuilderAppend(builder, ref.batch->column_data(i_col), ref.row); + ARROW_RETURN_NOT_OK(st); } else { builder.UnsafeAppendNull(); } @@ -480,14 +712,21 @@ class CompositeReferenceTable { } }; +// TODO: Currently, AsofJoinNode uses 64-bit hashing which leads to a non-negligible +// probability of collision, which can cause incorrect results when many different by-key +// values are processed. Thus, AsofJoinNode is currently limited to about 100k by-keys for +// guaranteeing this probability is below 1 in a billion. The fix is 128-bit hashing. +// See ARROW-17653 class AsofJoinNode : public ExecNode { // Advances the RHS as far as possible to be up to date for the current LHS timestamp - bool UpdateRhs() { + Result UpdateRhs() { auto& lhs = *state_.at(0); auto lhs_latest_time = lhs.GetLatestTime(); bool any_updated = false; - for (size_t i = 1; i < state_.size(); ++i) - any_updated |= state_[i]->AdvanceAndMemoize(lhs_latest_time); + for (size_t i = 1; i < state_.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(bool advanced, state_[i]->AdvanceAndMemoize(lhs_latest_time)); + any_updated |= advanced; + } return any_updated; } @@ -495,7 +734,7 @@ class AsofJoinNode : public ExecNode { bool IsUpToDateWithLhsRow() const { auto& lhs = *state_[0]; if (lhs.Empty()) return false; // can't proceed if nothing on the LHS - int64_t lhs_ts = lhs.GetLatestTime(); + OnType lhs_ts = lhs.GetLatestTime(); for (size_t i = 1; i < state_.size(); ++i) { auto& rhs = *state_[i]; if (!rhs.Finished()) { @@ -523,7 +762,7 @@ class AsofJoinNode : public ExecNode { if (lhs.Finished() || lhs.Empty()) break; // Advance each of the RHS as far as possible to be up to date for the LHS timestamp - bool any_rhs_advanced = UpdateRhs(); + ARROW_ASSIGN_OR_RAISE(bool any_rhs_advanced, UpdateRhs()); // If we have received enough inputs to produce the next output batch // (decided by IsUpToDateWithLhsRow), we will perform the join and @@ -531,8 +770,9 @@ class AsofJoinNode : public ExecNode { // the LHS and adding joined row to rows_ (done by Emplace). Finally, // input batches that are no longer needed are removed to free up memory. 
if (IsUpToDateWithLhsRow()) { - dst.Emplace(state_, options_.tolerance); - if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch + dst.Emplace(state_, tolerance_); + ARROW_ASSIGN_OR_RAISE(bool advanced, lhs.Advance()); + if (!advanced) break; // if we can't advance LHS, we're done for this batch } else { if (!any_rhs_advanced) break; // need to wait for new data } @@ -541,8 +781,7 @@ class AsofJoinNode : public ExecNode { // Prune memo entries that have expired (to bound memory consumption) if (!lhs.Empty()) { for (size_t i = 1; i < state_.size(); ++i) { - state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - - options_.tolerance); + state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - tolerance_); } } @@ -572,7 +811,6 @@ class AsofJoinNode : public ExecNode { ExecBatch out_b(*out_rb); outputs_[0]->InputReceived(this, std::move(out_b)); } else { - StopProducing(); ErrorIfNotOk(result.status()); return; } @@ -584,8 +822,8 @@ class AsofJoinNode : public ExecNode { // It may happen here in cases where InputFinished was called before we were finished // producing results (so we didn't know the output size at that time) if (state_.at(0)->Finished()) { - StopProducing(); outputs_[0]->InputFinished(this, batches_produced_); + finished_.MarkFinished(); } } @@ -602,54 +840,172 @@ class AsofJoinNode : public ExecNode { public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema); + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key, + OnType tolerance, std::shared_ptr output_schema, + std::vector> key_hashers, bool must_hash, + bool may_rehash); + + Status Init() override { + auto inputs = this->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + RETURN_NOT_OK(key_hashers_[i]->Init(plan()->exec_context(), output_schema())); + state_.push_back(::arrow::internal::make_unique( + must_hash_, may_rehash_, key_hashers_[i].get(), inputs[i]->output_schema(), + indices_of_on_key_[i], indices_of_by_key_[i])); + } + + col_index_t dst_offset = 0; + for (auto& state : state_) + dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + + return Status::OK(); + } virtual ~AsofJoinNode() { process_.Push(false); // poison pill process_thread_.join(); } + const std::vector& indices_of_on_key() { return indices_of_on_key_; } + const std::vector>& indices_of_by_key() { + return indices_of_by_key_; + } + + static Status is_valid_on_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", + field->type()->ToString()); + } + } + + static Status is_valid_by_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return 
Status::Invalid("Unsupported type for by-key ", field->name(), " : ", + field->type()->ToString()); + } + } + + static Status is_valid_data_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::FLOAT: + case Type::DOUBLE: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for data field ", field->name(), " : ", + field->type()->ToString()); + } + } + static arrow::Result> MakeOutputSchema( - const std::vector& inputs, const AsofJoinNodeOptions& options) { + const std::vector& inputs, + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key) { std::vector> fields; - const auto& on_field_name = *options.on_key.name(); - const auto& by_field_name = *options.by_key.name(); - + size_t n_by = indices_of_by_key[0].size(); + const DataType* on_key_type = NULLPTR; + std::vector by_key_type(n_by, NULLPTR); // Take all non-key, non-time RHS fields for (size_t j = 0; j < inputs.size(); ++j) { const auto& input_schema = inputs[j]->output_schema(); - const auto& on_field_ix = input_schema->GetFieldIndex(on_field_name); - const auto& by_field_ix = input_schema->GetFieldIndex(by_field_name); + const auto& on_field_ix = indices_of_on_key[j]; + const auto& by_field_ix = indices_of_by_key[j]; - if ((on_field_ix == -1) | (by_field_ix == -1)) { + if ((on_field_ix == -1) || std_has(by_field_ix, -1)) { return Status::Invalid("Missing join key on table ", j); } + const auto& on_field = input_schema->fields()[on_field_ix]; + std::vector by_field(n_by); + for (size_t k = 0; k < n_by; k++) { + by_field[k] = input_schema->fields()[by_field_ix[k]].get(); + } + + if (on_key_type == NULLPTR) { + on_key_type = on_field->type().get(); + } else if (*on_key_type != *on_field->type()) { + return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", + *on_field->type(), " for field ", on_field->name(), + " in input ", j); + } + for (size_t k = 0; k < n_by; k++) { + if (by_key_type[k] == NULLPTR) { + by_key_type[k] = by_field[k]->type().get(); + } else if (*by_key_type[k] != *by_field[k]->type()) { + return Status::Invalid("Expected on-key type ", *by_key_type[k], " but got ", + *by_field[k]->type(), " for field ", by_field[k]->name(), + " in input ", j); + } + } + for (int i = 0; i < input_schema->num_fields(); ++i) { const auto field = input_schema->field(i); - if (field->name() == on_field_name) { - if (kSupportedOnTypes_.find(field->type()) == kSupportedOnTypes_.end()) { - return Status::Invalid("Unsupported type for on key: ", field->name()); - } + if (i == on_field_ix) { + ARROW_RETURN_NOT_OK(is_valid_on_field(field)); // Only add on field from the left table if (j == 0) { fields.push_back(field); } - } else if (field->name() == by_field_name) { - if (kSupportedByTypes_.find(field->type()) == kSupportedByTypes_.end()) { - return Status::Invalid("Unsupported type for by key: ", field->name()); - } + } else if (std_has(by_field_ix, i)) { + ARROW_RETURN_NOT_OK(is_valid_by_field(field)); // Only add by field from the left table if (j == 0) { fields.push_back(field); } } else { - if (kSupportedDataTypes_.find(field->type()) == kSupportedDataTypes_.end()) { - return 
Status::Invalid("Unsupported data type: ", field->name()); - } - + ARROW_RETURN_NOT_OK(is_valid_data_field(field)); fields.push_back(field); } } @@ -657,45 +1013,91 @@ class AsofJoinNode : public ExecNode { return std::make_shared(fields); } + static inline Result FindColIndex(const Schema& schema, + const FieldRef& field_ref, + util::string_view key_kind) { + auto match_res = field_ref.FindOne(schema); + if (!match_res.ok()) { + return Status::Invalid("Bad join key on table : ", match_res.status().message()); + } + ARROW_ASSIGN_OR_RAISE(auto match, match_res); + if (match.indices().size() != 1) { + return Status::Invalid("AsOfJoinNode does not support a nested ", + to_string(key_kind), "-key ", field_ref.ToString()); + } + return match.indices()[0]; + } + static arrow::Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; const auto& join_options = checked_cast(options); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, - MakeOutputSchema(inputs, join_options)); + if (join_options.tolerance < 0) { + return Status::Invalid("AsOfJoin tolerance must be non-negative but is ", + join_options.tolerance); + } - std::vector input_labels(inputs.size()); - input_labels[0] = "left"; - for (size_t i = 1; i < inputs.size(); ++i) { - input_labels[i] = "right_" + std::to_string(i); + size_t n_input = inputs.size(), n_by = join_options.by_key.size(); + std::vector input_labels(n_input); + std::vector indices_of_on_key(n_input); + std::vector> indices_of_by_key( + n_input, std::vector(n_by)); + for (size_t i = 0; i < n_input; ++i) { + input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i); + const Schema& input_schema = *inputs[i]->output_schema(); + ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], + FindColIndex(input_schema, join_options.on_key, "on")); + for (size_t k = 0; k < n_by; k++) { + ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], + FindColIndex(input_schema, join_options.by_key[k], "by")); + } } - return plan->EmplaceNode(plan, inputs, std::move(input_labels), - join_options, std::move(output_schema)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, + MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key)); + + std::vector> key_hashers; + for (size_t i = 0; i < n_input; i++) { + key_hashers.push_back( + ::arrow::internal::make_unique(indices_of_by_key[i])); + } + bool must_hash = + n_by > 1 || + (n_by == 1 && + !is_primitive( + inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id())); + bool may_rehash = n_by == 1 && !must_hash; + return plan->EmplaceNode( + plan, inputs, std::move(input_labels), std::move(indices_of_on_key), + std::move(indices_of_by_key), time_value(join_options.tolerance), + std::move(output_schema), std::move(key_hashers), must_hash, may_rehash); } const char* kind_name() const override { return "AsofJoinNode"; } void InputReceived(ExecNode* input, ExecBatch batch) override { // Get the input - ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); - state_.at(k)->Push(rb); + Status st = state_.at(k)->Push(rb); + if (!st.ok()) { + ErrorReceived(input, st); + return; + } process_.Push(true); } void ErrorReceived(ExecNode* input, Status error) override { 
outputs_[0]->ErrorReceived(this, std::move(error)); - StopProducing(); } void InputFinished(ExecNode* input, int total_batches) override { { std::lock_guard guard(gate_); - ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); state_.at(k)->set_total_batches(total_batches); } // Trigger a process call @@ -714,20 +1116,24 @@ class AsofJoinNode : public ExecNode { DCHECK_EQ(output, outputs_[0]); StopProducing(); } - void StopProducing() override { finished_.MarkFinished(); } + void StopProducing() override { + process_.Clear(); + process_.Push(false); + } arrow::Future<> finished() override { return finished_; } private: - static const std::set> kSupportedOnTypes_; - static const std::set> kSupportedByTypes_; - static const std::set> kSupportedDataTypes_; - arrow::Future<> finished_; + std::vector indices_of_on_key_; + std::vector> indices_of_by_key_; + std::vector> key_hashers_; + bool must_hash_; + bool may_rehash_; // InputStates // Each input state correponds to an input table std::vector> state_; std::mutex gate_; - AsofJoinNodeOptions options_; + OnType tolerance_; // Queue for triggering processing of a given input // (a false value is a poison pill) @@ -741,30 +1147,25 @@ class AsofJoinNode : public ExecNode { AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema) + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key, + OnType tolerance, std::shared_ptr output_schema, + std::vector> key_hashers, + bool must_hash, bool may_rehash) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), - options_(join_options), + indices_of_on_key_(std::move(indices_of_on_key)), + indices_of_by_key_(std::move(indices_of_by_key)), + key_hashers_(std::move(key_hashers)), + must_hash_(must_hash), + may_rehash_(may_rehash), + tolerance_(tolerance), process_(), process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { - for (size_t i = 0; i < inputs.size(); ++i) - state_.push_back(::arrow::internal::make_unique( - inputs[i]->output_schema(), *options_.on_key.name(), *options_.by_key.name())); - col_index_t dst_offset = 0; - for (auto& state : state_) - dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); - finished_ = arrow::Future<>::MakeFinished(); } -// Currently supported types -const std::set> AsofJoinNode::kSupportedOnTypes_ = {int64()}; -const std::set> AsofJoinNode::kSupportedByTypes_ = {int32()}; -const std::set> AsofJoinNode::kSupportedDataTypes_ = { - int32(), int64(), float32(), float64()}; - namespace internal { void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make)); diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 8b993764abe..48d1ae6410b 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -17,11 +17,13 @@ #include +#include #include #include #include #include "arrow/api.h" +#include "arrow/compute/api_scalar.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/test_util.h" #include "arrow/compute/exec/util.h" @@ -32,23 +34,185 @@ #include "arrow/testing/random.h" 
#include "arrow/util/checked_cast.h" #include "arrow/util/make_unique.h" +#include "arrow/util/string_view.h" #include "arrow/util/thread_pool.h" +#define TRACED_TEST(t_class, t_name, t_body) \ + TEST(t_class, t_name) { \ + ARROW_SCOPED_TRACE(#t_class "_" #t_name); \ + t_body; \ + } + +#define TRACED_TEST_P(t_class, t_name, t_body) \ + TEST_P(t_class, t_name) { \ + ARROW_SCOPED_TRACE(#t_class "_" #t_name "_" + std::get<1>(GetParam())); \ + t_body; \ + } + using testing::UnorderedElementsAreArray; namespace arrow { namespace compute { +bool is_temporal_primitive(Type::type type_id) { + switch (type_id) { + case Type::TIME32: + case Type::TIME64: + case Type::DATE32: + case Type::DATE64: + case Type::TIMESTAMP: + return true; + default: + return false; + } +} + +Result MakeBatchesFromNumString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity = 1) { + FieldVector num_fields; + for (auto field : schema->fields()) { + num_fields.push_back( + is_base_binary_like(field->type()->id()) ? field->WithType(int64()) : field); + } + auto num_schema = + std::make_shared(num_fields, schema->endianness(), schema->metadata()); + BatchesWithSchema num_batches = + MakeBatchesFromString(num_schema, json_strings, multiplicity); + BatchesWithSchema batches; + batches.schema = schema; + int n_fields = schema->num_fields(); + for (auto num_batch : num_batches.batches) { + std::vector values; + for (int i = 0; i < n_fields; i++) { + auto type = schema->field(i)->type(); + if (is_base_binary_like(type->id())) { + // casting to string first enables casting to binary + ARROW_ASSIGN_OR_RAISE(Datum as_string, Cast(num_batch.values[i], utf8())); + ARROW_ASSIGN_OR_RAISE(Datum as_type, Cast(as_string, type)); + values.push_back(as_type); + } else { + values.push_back(num_batch.values[i]); + } + } + ExecBatch batch(values, num_batch.length); + batches.batches.push_back(batch); + } + return batches; +} + +void BuildNullArray(std::shared_ptr& empty, const std::shared_ptr& type, + int64_t length) { + ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool())); + ASSERT_OK(builder->Reserve(length)); + ASSERT_OK(builder->AppendNulls(length)); + ASSERT_OK(builder->Finish(&empty)); +} + +void BuildZeroPrimitiveArray(std::shared_ptr& empty, + const std::shared_ptr& type, int64_t length) { + ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool())); + ASSERT_OK(builder->Reserve(length)); + ASSERT_OK_AND_ASSIGN(auto scalar, MakeScalar(type, 0)); + ASSERT_OK(builder->AppendScalar(*scalar, length)); + ASSERT_OK(builder->Finish(&empty)); +} + +template +void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { + Builder builder(default_memory_pool()); + ASSERT_OK(builder.Reserve(length)); + for (int64_t i = 0; i < length; i++) { + ASSERT_OK(builder.Append("0", /*length=*/1)); + } + ASSERT_OK(builder.Finish(&empty)); +} + +// mutates by copying from_key into to_key and changing from_key to zero +Result MutateByKey(BatchesWithSchema& batches, std::string from_key, + std::string to_key, bool replace_key = false, + bool null_key = false, bool remove_key = false) { + int from_index = batches.schema->GetFieldIndex(from_key); + int n_fields = batches.schema->num_fields(); + auto fields = batches.schema->fields(); + BatchesWithSchema new_batches; + if (remove_key) { + ARROW_ASSIGN_OR_RAISE(new_batches.schema, batches.schema->RemoveField(from_index)); + } else { + auto new_field = batches.schema->field(from_index)->WithName(to_key); + 
ARROW_ASSIGN_OR_RAISE(new_batches.schema, + replace_key ? batches.schema->SetField(from_index, new_field) + : batches.schema->AddField(from_index, new_field)); + } + for (const ExecBatch& batch : batches.batches) { + std::vector new_values; + for (int i = 0; i < n_fields; i++) { + const Datum& value = batch.values[i]; + if (i == from_index) { + if (remove_key) { + continue; + } + auto type = fields[i]->type(); + if (null_key) { + std::shared_ptr empty; + BuildNullArray(empty, type, batch.length); + new_values.push_back(empty); + } else if (is_primitive(type->id())) { + std::shared_ptr empty; + BuildZeroPrimitiveArray(empty, type, batch.length); + new_values.push_back(empty); + } else if (is_base_binary_like(type->id())) { + std::shared_ptr empty; + switch (type->id()) { + case Type::STRING: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::LARGE_STRING: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::BINARY: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::LARGE_BINARY: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + default: + DCHECK(false); + break; + } + new_values.push_back(empty); + } else { + ARROW_ASSIGN_OR_RAISE(auto sub, Subtract(value, value)); + new_values.push_back(sub); + } + if (replace_key) { + continue; + } + } + new_values.push_back(value); + } + new_batches.batches.emplace_back(new_values, batch.length); + } + return new_batches; +} + +// code generation for the by_key types supported by AsofJoinNodeOptions constructors +// which cannot be directly done using templates because of failure to deduce the template +// argument for an invocation with a string- or initializer_list-typed keys-argument +#define EXPAND_BY_KEY_TYPE(macro) \ + macro(const FieldRef); \ + macro(std::vector); \ + macro(std::initializer_list); + void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, const BatchesWithSchema& r1_batches, - const BatchesWithSchema& exp_batches, const FieldRef time, - const FieldRef keys, const int64_t tolerance) { + const BatchesWithSchema& exp_batches, + const AsofJoinNodeOptions join_options) { auto exec_ctx = arrow::internal::make_unique(default_memory_pool(), nullptr); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - AsofJoinNodeOptions join_options(time, keys, tolerance); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ @@ -64,6 +228,9 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, .AddToPlan(plan.get())); ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + for (auto batch : res) { + ASSERT_EQ(exp_batches.schema->num_fields(), batch.values.size()); + } ASSERT_OK_AND_ASSIGN(auto exp_table, TableFromExecBatches(exp_batches.schema, exp_batches.batches)); @@ -74,237 +241,783 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, /*same_chunk_layout=*/true, /*flatten=*/true); } -void DoRunBasicTest(const std::vector& l_data, - const std::vector& r0_data, - const std::vector& r1_data, - const std::vector& exp_data, int64_t tolerance) { - auto l_schema = - schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); - auto r0_schema = - schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())}); - auto r1_schema = - schema({field("time", int64()), field("key", int32()), field("r1_v0", float32())}); - - auto exp_schema = schema({ - field("time", int64()), - field("key", int32()), - field("l_v0", float64()), - 
field("r0_v0", float64()), - field("r1_v0", float32()), - }); - - // Test three table join - BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; - l_batches = MakeBatchesFromString(l_schema, l_data); - r0_batches = MakeBatchesFromString(r0_schema, r0_data); - r1_batches = MakeBatchesFromString(r1_schema, r1_data); - exp_batches = MakeBatchesFromString(exp_schema, exp_data); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", - tolerance); -} - -void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema) { - BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); - BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); - +#define CHECK_RUN_OUTPUT(by_key_type) \ + void CheckRunOutput( \ + const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, \ + const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \ + const FieldRef time, by_key_type key, const int64_t tolerance) { \ + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \ + AsofJoinNodeOptions(time, {key}, tolerance)); \ + } + +EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) + +void DoInvalidPlanTest(const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str, + bool fail_on_plan_creation = false) { ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); - AsofJoinNodeOptions join_options("time", "key", 0); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); - ASSERT_RAISES(Invalid, join.AddToPlan(plan.get())); + if (fail_on_plan_creation) { + AsyncGenerator> sink_gen; + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr(expected_error_str), + StartAndCollect(plan.get(), sink_gen)); + } else { + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_error_str), + join.AddToPlan(plan.get())); + } +} + +void DoRunInvalidPlanTest(const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str) { + DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str); +} + +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str) { + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {R"([])"})); + ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {R"([])"})); + + return DoRunInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str); +} + +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, int64_t tolerance, + const std::string& expected_error_str) { + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("time", {"key"}, tolerance), + expected_error_str); +} + +void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Unsupported type for "); +} + +void DoRunInvalidToleranceTest(const std::shared_ptr& 
l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, -1, + "AsOfJoin tolerance must be non-negative but is "); +} + +void DoRunMissingKeysTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No match"); +} + +void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("invalid_time", {"key"}, 0), + "Bad join key on table : No match"); +} + +void DoRunMissingByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("time", {"invalid_key"}, 0), + "Bad join key on table : No match"); +} + +void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, {"key"}, 0), + "Bad join key on table : No match"); +} + +void DoRunNestedByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0), + "Bad join key on table : No match"); +} + +void DoRunAmbiguousOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); +} + +void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); +} + +// Gets a batch for testing as a Json string +// The batch will have n_rows rows n_cols columns, the first column being the on-field +// If unordered is true then the first column will be out-of-order +std::string GetTestBatchAsJsonString(int n_rows, int n_cols, bool unordered = false) { + int order_mask = unordered ? 
1 : 0; + std::stringstream s; + s << '['; + for (int i = 0; i < n_rows; i++) { + if (i > 0) { + s << ", "; + } + s << '['; + for (int j = 0; j < n_cols; j++) { + if (j > 0) { + s << ", " << j; + } else if (j < 2) { + s << (i ^ order_mask); + } else { + s << i; + } + } + s << ']'; + } + s << ']'; + return s.str(); +} + +void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, + const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str) { + ASSERT_TRUE(l_unordered || r_unordered); + int n_rows = 5; + auto l_str = GetTestBatchAsJsonString(n_rows, l_schema->num_fields(), l_unordered); + auto r_str = GetTestBatchAsJsonString(n_rows, r_schema->num_fields(), r_unordered); + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {l_str})); + ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {r_str})); + + return DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str, + /*then_run_plan=*/true); } +void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, + const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema, + AsofJoinNodeOptions("time", {"key"}, 1000), + "out-of-order on-key values"); +} + +struct BasicTestTypes { + std::shared_ptr time, key, l_val, r0_val, r1_val; +}; + +struct BasicTest { + BasicTest(const std::vector& l_data, + const std::vector& r0_data, + const std::vector& r1_data, + const std::vector& exp_nokey_data, + const std::vector& exp_emptykey_data, + const std::vector& exp_data, int64_t tolerance) + : l_data(std::move(l_data)), + r0_data(std::move(r0_data)), + r1_data(std::move(r1_data)), + exp_nokey_data(std::move(exp_nokey_data)), + exp_emptykey_data(std::move(exp_emptykey_data)), + exp_data(std::move(exp_data)), + tolerance(tolerance) {} + + static inline void check_init(const std::vector>& types) { + ASSERT_NE(0, types.size()); + } + + template + static inline std::vector> init_types( + const std::vector>& all_types, TypeCond type_cond) { + std::vector> types; + for (auto type : all_types) { + if (type_cond(type)) { + types.push_back(type); + } + } + check_init(types); + return types; + } + + void RunSingleByKey() { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_emptykey_batches, B exp_batches) { + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", + tolerance); + }); + } + static void DoSingleByKey(BasicTest& basic_tests) { basic_tests.RunSingleByKey(); } + void RunDoubleByKey() { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_emptykey_batches, B exp_batches) { + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", + {"key", "key"}, tolerance); + }); + } + static void DoDoubleByKey(BasicTest& basic_tests) { basic_tests.RunDoubleByKey(); } + void RunMutateByKey() { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_emptykey_batches, B exp_batches) { + ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2")); + ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2")); + ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2")); + ASSERT_OK_AND_ASSIGN(exp_batches, MutateByKey(exp_batches, "key", "key2")); + CheckRunOutput(l_batches, 
r0_batches, r1_batches, exp_batches, "time", + {"key", "key2"}, tolerance); + }); + } + static void DoMutateByKey(BasicTest& basic_tests) { basic_tests.RunMutateByKey(); } + void RunMutateNoKey() { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_emptykey_batches, B exp_batches) { + ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true)); + ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2", true)); + ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2", true)); + ASSERT_OK_AND_ASSIGN(exp_nokey_batches, + MutateByKey(exp_nokey_batches, "key", "key2", true)); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, "time", "key2", + tolerance); + }); + } + static void DoMutateNoKey(BasicTest& basic_tests) { basic_tests.RunMutateNoKey(); } + void RunMutateNullKey() { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_emptykey_batches, B exp_batches) { + ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true, true)); + ASSERT_OK_AND_ASSIGN(r0_batches, + MutateByKey(r0_batches, "key", "key2", true, true)); + ASSERT_OK_AND_ASSIGN(r1_batches, + MutateByKey(r1_batches, "key", "key2", true, true)); + ASSERT_OK_AND_ASSIGN(exp_nokey_batches, + MutateByKey(exp_nokey_batches, "key", "key2", true, true)); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, + AsofJoinNodeOptions("time", {"key2"}, tolerance)); + }); + } + static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); } + void RunMutateEmptyKey() { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_emptykey_batches, B exp_batches) { + ASSERT_OK_AND_ASSIGN(r0_batches, + MutateByKey(r0_batches, "key", "key", false, false, true)); + ASSERT_OK_AND_ASSIGN(r1_batches, + MutateByKey(r1_batches, "key", "key", false, false, true)); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches, + AsofJoinNodeOptions("time", {}, tolerance)); + }); + } + static void DoMutateEmptyKey(BasicTest& basic_tests) { + basic_tests.RunMutateEmptyKey(); + } + template + void RunBatches(BatchesRunner batches_runner) { + std::vector> all_types = { + utf8(), + large_utf8(), + binary(), + large_binary(), + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + date32(), + date64(), + time32(TimeUnit::MILLI), + time32(TimeUnit::SECOND), + time64(TimeUnit::NANO), + time64(TimeUnit::MICRO), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC"), + float32(), + float64()}; + using T = const std::shared_ptr; + // byte_width > 1 below allows fitting the tested data + auto time_types = init_types( + all_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); + auto key_types = init_types(all_types, [](T& t) { return !is_floating(t->id()); }); + auto l_types = init_types(all_types, [](T& t) { return true; }); + auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); + auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); + + // sample a limited number of type-combinations to keep the runnning time reasonable + // the scoped-traces below help reproduce a test failure, should it happen + auto start_time = 
std::chrono::system_clock::now(); + auto seed = start_time.time_since_epoch().count(); + ARROW_SCOPED_TRACE("Types seed: ", seed); + std::default_random_engine engine(static_cast(seed)); + std::uniform_int_distribution time_distribution(0, time_types.size() - 1); + std::uniform_int_distribution key_distribution(0, key_types.size() - 1); + std::uniform_int_distribution l_distribution(0, l_types.size() - 1); + std::uniform_int_distribution r0_distribution(0, r0_types.size() - 1); + std::uniform_int_distribution r1_distribution(0, r1_types.size() - 1); + + for (int i = 0; i < 1000; i++) { + auto time_type = time_types[time_distribution(engine)]; + ARROW_SCOPED_TRACE("Time type: ", *time_type); + auto key_type = key_types[key_distribution(engine)]; + ARROW_SCOPED_TRACE("Key type: ", *key_type); + auto l_type = l_types[l_distribution(engine)]; + ARROW_SCOPED_TRACE("Left type: ", *l_type); + auto r0_type = r0_types[r0_distribution(engine)]; + ARROW_SCOPED_TRACE("Right-0 type: ", *r0_type); + auto r1_type = r1_types[r1_distribution(engine)]; + ARROW_SCOPED_TRACE("Right-1 type: ", *r1_type); + + RunTypes({time_type, key_type, l_type, r0_type, r1_type}, batches_runner); + + auto end_time = std::chrono::system_clock::now(); + std::chrono::duration diff = end_time - start_time; + if (diff.count() > 2) { + // this normally happens on slow CI systems, but is fine + break; + } + } + } + template + void RunTypes(BasicTestTypes basic_test_types, BatchesRunner batches_runner) { + const BasicTestTypes& b = basic_test_types; + auto l_schema = + schema({field("time", b.time), field("key", b.key), field("l_v0", b.l_val)}); + auto r0_schema = + schema({field("time", b.time), field("key", b.key), field("r0_v0", b.r0_val)}); + auto r1_schema = + schema({field("time", b.time), field("key", b.key), field("r1_v0", b.r1_val)}); + + auto exp_schema = schema({ + field("time", b.time), + field("key", b.key), + field("l_v0", b.l_val), + field("r0_v0", b.r0_val), + field("r1_v0", b.r1_val), + }); + + // Test three table join + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, l_data)); + ASSERT_OK_AND_ASSIGN(auto r0_batches, MakeBatchesFromNumString(r0_schema, r0_data)); + ASSERT_OK_AND_ASSIGN(auto r1_batches, MakeBatchesFromNumString(r1_schema, r1_data)); + ASSERT_OK_AND_ASSIGN(auto exp_nokey_batches, + MakeBatchesFromNumString(exp_schema, exp_nokey_data)); + ASSERT_OK_AND_ASSIGN(auto exp_emptykey_batches, + MakeBatchesFromNumString(exp_schema, exp_emptykey_data)); + ASSERT_OK_AND_ASSIGN(auto exp_batches, + MakeBatchesFromNumString(exp_schema, exp_data)); + batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches, + exp_emptykey_batches, exp_batches); + } + + std::vector l_data; + std::vector r0_data; + std::vector r1_data; + std::vector exp_nokey_data; + std::vector exp_emptykey_data; + std::vector exp_data; + int64_t tolerance; +}; + +using AsofJoinBasicParams = std::tuple, std::string>; + +struct AsofJoinBasicTest : public testing::TestWithParam {}; + class AsofJoinTest : public testing::Test {}; -TEST(AsofJoinTest, TestBasic1) { +BasicTest GetBasicTest1() { // Single key, single batch - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0], [1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0]])"}, - /*r1*/ {R"([[1000, 1, 101.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1], [1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11]])"}, + /*r1*/ {R"([[1000, 1, 101]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, null], [1000, 0, 2, 11, 101]])"}, 
+ /*exp_emptykey*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, + /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, 1000); } -TEST(AsofJoinTest, TestBasic2) { +TRACED_TEST_P(AsofJoinBasicTest, TestBasic1, { + BasicTest basic_test = GetBasicTest1(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest2() { // Single key, multiple batches - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"}, - /*r1*/ {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11]])", R"([[1000, 1, 12]])"}, + /*r1*/ {R"([[0, 1, 101]])", R"([[1000, 1, 102]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, + /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, + /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } -TEST(AsofJoinTest, TestBasic3) { +TRACED_TEST_P(AsofJoinBasicTest, TestBasic2, { + BasicTest basic_test = GetBasicTest2(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest3() { // Single key, multiple left batches, single right batches - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0], [1000, 1, 12.0]])"}, - /*r1*/ {R"([[0, 1, 101.0], [1000, 1, 102.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11], [1000, 1, 12]])"}, + /*r1*/ {R"([[0, 1, 101], [1000, 1, 102]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, + /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, + /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } -TEST(AsofJoinTest, TestBasic4) { +TRACED_TEST_P(AsofJoinBasicTest, TestBasic3, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic3_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest3(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest4() { // Multi key, multiple batches, misaligned batches - DoRunBasicTest( + return BasicTest( /*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], 
[0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, 1001], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } -TEST(AsofJoinTest, TestBasic5) { +TRACED_TEST_P(AsofJoinBasicTest, TestBasic4, { + BasicTest basic_test = GetBasicTest4(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest5() { // Multi key, multiple batches, misaligned batches, smaller tolerance - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 500); -} - -TEST(AsofJoinTest, TestBasic6) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, null], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, + 500); +} + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic5, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic5_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest5(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest6() { // Multi key, multiple batches, misaligned batches, zero tolerance - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null, null], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null, null]])"}, - 0); -} - -TEST(AsofJoinTest, TestEmpty1) { + 
return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, null], [1500, 1, 3, null, null], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"}, + 0); +} + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic6, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic6_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest6(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetEmptyTest1() { // Empty left batch - DoRunBasicTest(/*l*/ - {R"([])", R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty2) { + return BasicTest(/*l*/ + {R"([])", R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, + /*exp*/ + {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); +} + +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty1, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty1_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest1(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetEmptyTest2() { // Empty left input - DoRunBasicTest(/*l*/ - {R"([])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([])"}, 1000); -} - -TEST(AsofJoinTest, TestEmpty3) { + return BasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([])"}, + /*exp_emptykey*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); +} + +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty2, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty2_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest2(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) 
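For reference, a minimal sketch (not part of this patch) of constructing the vector-valued "by" key options that the multi-key runners instantiated further below (e.g. "DoubleByKey") exercise, matching the AsofJoinNodeOptions change to options.h later in this diff; the field names "time", "key", and "key2" are hypothetical:

#include "arrow/compute/exec/options.h"

// Join inexactly on "time" within a tolerance of 1000 units, and exactly on
// both "key" and "key2" (by_key now accepts several fields).
arrow::compute::AsofJoinNodeOptions MakeExampleAsofJoinOptions() {
  return arrow::compute::AsofJoinNodeOptions(
      /*on_key=*/arrow::FieldRef("time"),
      /*by_key=*/{arrow::FieldRef("key"), arrow::FieldRef("key2")},
      /*tolerance=*/1000);
}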
+ +BasicTest GetEmptyTest3() { // Empty right batch - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([])", R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty4) { - // Empty right input - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, null, 1002.0]])", - R"([[2000, 1, 4.0, null, 103.0], [2000, 2, 24.0, null, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty5) { - // All empty - DoRunBasicTest(/*l*/ - {R"([])"}, - /*r0*/ - {R"([])"}, - /*r1*/ - {R"([])"}, - /*exp*/ - {R"([])"}, 1000); + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([])", R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, + 1000); } -TEST(AsofJoinTest, TestUnsupportedOntype) { - DoRunInvalidTypeTest( - schema({field("time", utf8()), field("key", int32()), field("l_v0", float64())}), - schema({field("time", utf8()), field("key", int32()), field("r0_v0", float32())})); +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty3, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty3_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest3(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetEmptyTest4() { + // Empty right input + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, null, 1002], [1500, 0, 23, null, 1002]])", + R"([[2000, 0, 4, null, 103], [2000, 
0, 24, null, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, null, 1002], [1500, 2, 23, null, 1002]])", + R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, null, 1002]])", + R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"}, + 1000); } -TEST(AsofJoinTest, TestUnsupportedBytype) { - DoRunInvalidTypeTest( - schema({field("time", int64()), field("key", utf8()), field("l_v0", float64())}), - schema({field("time", int64()), field("key", utf8()), field("r0_v0", float32())})); +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty4, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty4_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest4(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetEmptyTest5() { + // All empty + return BasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([])"}, + /*exp_nokey*/ + {R"([])"}, + /*exp_emptykey*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); } -TEST(AsofJoinTest, TestUnsupportedDatatype) { - // Utf8 is unsupported +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty5, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty5_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest5(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +INSTANTIATE_TEST_SUITE_P( + AsofJoinNodeTest, AsofJoinBasicTest, + testing::Values(AsofJoinBasicParams(BasicTest::DoSingleByKey, "SingleByKey"), + AsofJoinBasicParams(BasicTest::DoDoubleByKey, "DoubleByKey"), + AsofJoinBasicParams(BasicTest::DoMutateByKey, "MutateByKey"), + AsofJoinBasicParams(BasicTest::DoMutateNoKey, "MutateNoKey"), + AsofJoinBasicParams(BasicTest::DoMutateNullKey, "MutateNullKey"), + AsofJoinBasicParams(BasicTest::DoMutateEmptyKey, "MutateEmptyKey"))); + +TRACED_TEST(AsofJoinTest, TestUnsupportedOntype, { + DoRunInvalidTypeTest(schema({field("time", list(int32())), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", list(int32())), field("key", int32()), + field("r0_v0", float32())})); +}) + +TRACED_TEST(AsofJoinTest, TestUnsupportedBytype, { + DoRunInvalidTypeTest(schema({field("time", int64()), field("key", list(int32())), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", list(int32())), + field("r0_v0", float32())})); +}) + +TRACED_TEST(AsofJoinTest, TestUnsupportedDatatype, { + // List is unsupported DoRunInvalidTypeTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), - schema({field("time", int64()), field("key", int32()), field("r0_v0", utf8())})); -} + schema({field("time", int64()), field("key", int32()), + field("r0_v0", list(int32()))})); +}) -TEST(AsofJoinTest, TestMissingKeys) { - DoRunInvalidTypeTest( +TRACED_TEST(AsofJoinTest, TestMissingKeys, { + DoRunMissingKeysTest( schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}), schema( {field("time1", int64()), field("key", int32()), field("r0_v0", float64())})); - DoRunInvalidTypeTest( + DoRunMissingKeysTest( schema({field("time", int64()), field("key1", int32()), field("l_v0", float64())}), schema( {field("time", int64()), field("key1", int32()), field("r0_v0", float64())})); -} +}) + +TRACED_TEST(AsofJoinTest, TestUnsupportedTolerance, { + // Utf8 is unsupported + DoRunInvalidToleranceTest( + 
schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestMissingOnKey, { + DoRunMissingOnKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestMissingByKey, { + DoRunMissingByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestNestedOnKey, { + DoRunNestedOnKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestNestedByKey, { + DoRunNestedByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestAmbiguousOnKey, { + DoRunAmbiguousOnKeyTest( + schema({field("time", int64()), field("time", int64()), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestAmbiguousByKey, { + DoRunAmbiguousByKeyTest( + schema({field("time", int64()), field("key", int64()), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestLeftUnorderedOnKey, { + DoRunUnorderedPlanTest( + /*l_unordered=*/true, /*r_unordered=*/false, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestRightUnorderedOnKey, { + DoRunUnorderedPlanTest( + /*l_unordered=*/false, /*r_unordered=*/true, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestUnorderedOnKey, { + DoRunUnorderedPlanTest( + /*l_unordered=*/true, /*r_unordered=*/true, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 5cf66b3d09e..da1710fe08d 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -26,7 +26,6 @@ #include #include "arrow/compute/exec/hash_join_dict.h" -#include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/task_util.h" #include "arrow/compute/kernels/row_encoder.h" #include "arrow/compute/row/encode_internal.h" diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index a8e8c1ee230..e0172bff7f7 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -397,23 +397,25 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. 
class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance) - : on_key(std::move(on_key)), by_key(std::move(by_key)), tolerance(tolerance) {} + AsofJoinNodeOptions(FieldRef on_key, std::vector<FieldRef> by_key, int64_t tolerance) + : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} - /// \brief "on" key for the join. Each + /// \brief "on" key for the join. /// - /// All inputs tables must be sorted by the "on" key. Inexact - match is used on the "on" key. i.e., a row is considiered match iff + /// All input tables must be sorted by the "on" key. Must be a single field of a common + type. Inexact match is used on the "on" key, i.e., a row is considered a match iff /// left_on - tolerance <= right_on <= left_on. - /// Currently, "on" key must be an int64 field + /// Currently, the "on" key must be of an integer, date, or timestamp type. FieldRef on_key; /// \brief "by" key for the join. /// /// All input tables must have the "by" key. Exact equality /// is used for the "by" key. - /// Currently, the "by" key must be an int32 field - FieldRef by_key; - /// Tolerance for inexact "on" key matching + /// Currently, the "by" key must be of an integer, date, timestamp, or base-binary type. + std::vector<FieldRef> by_key; + /// \brief Tolerance for inexact "on" key matching. Must be non-negative. + /// + /// The tolerance is interpreted in the same units as the "on" key. int64_t tolerance; }; diff --git a/cpp/src/arrow/compute/light_array.cc b/cpp/src/arrow/compute/light_array.cc index 4bf3574d09f..9ea609c5310 100644 --- a/cpp/src/arrow/compute/light_array.cc +++ b/cpp/src/arrow/compute/light_array.cc @@ -141,6 +141,12 @@ Result<KeyColumnArray> ColumnArrayFromArrayData( const std::shared_ptr<ArrayData>& array_data, int64_t start_row, int64_t num_rows) { ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata metadata, ColumnMetadataFromDataType(array_data->type)); + return ColumnArrayFromArrayDataAndMetadata(array_data, metadata, start_row, num_rows); +} + +KeyColumnArray ColumnArrayFromArrayDataAndMetadata( + const std::shared_ptr<ArrayData>& array_data, const KeyColumnMetadata& metadata, + int64_t start_row, int64_t num_rows) { KeyColumnArray column_array = KeyColumnArray( metadata, array_data->offset + start_row + num_rows, array_data->buffers[0] != NULLPTR ?
array_data->buffers[0]->data() : nullptr, diff --git a/cpp/src/arrow/compute/light_array.h b/cpp/src/arrow/compute/light_array.h index 0620f6d3eb1..389b63cca41 100644 --- a/cpp/src/arrow/compute/light_array.h +++ b/cpp/src/arrow/compute/light_array.h @@ -135,7 +135,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a varbinary type uint32_t* mutable_offsets() { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint32_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint32_t)); return reinterpret_cast<uint32_t*>(mutable_data(kFixedLengthBuffer)); } /// \brief Return a read-only version of the offsets buffer @@ -143,7 +143,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a varbinary type const uint32_t* offsets() const { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint32_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint32_t)); return reinterpret_cast<const uint32_t*>(data(kFixedLengthBuffer)); } /// \brief Return a mutable version of the large-offsets buffer @@ -151,7 +151,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a large varbinary type uint64_t* mutable_large_offsets() { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint64_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint64_t)); return reinterpret_cast<uint64_t*>(mutable_data(kFixedLengthBuffer)); } /// \brief Return a read-only version of the large-offsets buffer @@ -159,7 +159,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a large varbinary type const uint64_t* large_offsets() const { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint64_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint64_t)); return reinterpret_cast<const uint64_t*>(data(kFixedLengthBuffer)); } /// \brief Return the type metadata @@ -205,6 +205,17 @@ ARROW_EXPORT Result<KeyColumnMetadata> ColumnMetadataFromDataType( ARROW_EXPORT Result<KeyColumnArray> ColumnArrayFromArrayData( const std::shared_ptr<ArrayData>& array_data, int64_t start_row, int64_t num_rows); +/// \brief Create KeyColumnArray from ArrayData and KeyColumnMetadata +/// +/// If `type` is a dictionary type then this will return the KeyColumnArray for +/// the indices array. +/// +/// The caller should ensure this is only called on "key" columns. +/// \see ColumnMetadataFromDataType for details +ARROW_EXPORT KeyColumnArray ColumnArrayFromArrayDataAndMetadata( + const std::shared_ptr<ArrayData>& array_data, const KeyColumnMetadata& metadata, + int64_t start_row, int64_t num_rows); + /// \brief Create KeyColumnMetadata instances from an ExecBatch /// /// column_metadatas will be resized to fit diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 66da3cadcb5..e2b74e865fd 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -622,6 +622,13 @@ using is_fixed_size_binary_type = std::is_base_of<FixedSizeBinaryType, T>; template <typename T, typename R = void> using enable_if_fixed_size_binary = enable_if_t<is_fixed_size_binary_type<T>::value, R>; +// This includes primitive, dictionary, and fixed-size-binary types +template <typename T> +using is_fixed_width_type = std::is_base_of<FixedWidthType, T>; + +template <typename T, typename R = void> +using enable_if_fixed_width_type = enable_if_t<is_fixed_width_type<T>::value, R>; + template using is_binary_like_type = std::integral_constant::value &&