diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 3b49f4ca00a..4702a427bad 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -381,6 +381,7 @@ if(ARROW_COMPUTE) compute/cast.cc compute/exec.cc compute/exec/aggregate_node.cc + compute/exec/asof_join_node.cc compute/exec/bloom_filter.cc compute/exec/exec_plan.cc compute/exec/expression.cc diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt index b2a21c2bd6b..f19921f4f28 100644 --- a/cpp/src/arrow/compute/exec/CMakeLists.txt +++ b/cpp/src/arrow/compute/exec/CMakeLists.txt @@ -32,6 +32,11 @@ add_arrow_compute_test(hash_join_node_test hash_join_node_test.cc bloom_filter_test.cc key_hash_test.cc) +add_arrow_compute_test(asof_join_node_test + PREFIX + "arrow-compute" + SOURCES + asof_join_node_test.cc) add_arrow_compute_test(tpch_node_test PREFIX "arrow-compute") add_arrow_compute_test(union_node_test PREFIX "arrow-compute") add_arrow_compute_test(util_test PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc new file mode 100644 index 00000000000..93eca8dbfb6 --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -0,0 +1,773 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include "arrow/array/builder_primitive.h" +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/util.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/optional.h" + +namespace arrow { +namespace compute { + +// Remove this when multiple keys and/or types are supported +typedef int32_t KeyType; + +// Maximum number of tables that can be joined +#define MAX_JOIN_TABLES 64 +typedef uint64_t row_index_t; +typedef int col_index_t; + +/** + * Simple implementation of an unbounded concurrent queue + */ +template +class ConcurrentQueue { + public: + // Blocks until the queue is non-empty, then removes and returns the oldest item + T Pop() { + std::unique_lock lock(mutex_); + cond_.wait(lock, [&] { return !queue_.empty(); }); + auto item = queue_.front(); + queue_.pop(); + return item; + } + + // Appends a copy of the item and wakes one waiting consumer + void Push(const T& item) { + std::unique_lock lock(mutex_); + queue_.push(item); + cond_.notify_one(); + } + + util::optional TryPop() { + // Try to pop the oldest value from the queue (or return nullopt if none) + std::unique_lock lock(mutex_); + if (queue_.empty()) { + return util::nullopt; + } else { + auto item = queue_.front(); + queue_.pop(); + return item; + } + } + + // Returns true if the queue currently holds no items (takes the lock) + bool Empty() const { + std::unique_lock lock(mutex_); + return queue_.empty(); + } + + // Un-synchronized access to front + // For this to be "safe": + // 1) the caller logically guarantees that queue is not empty + // 2) pop/try_pop cannot be called concurrently with this + const T& UnsyncFront() const { return queue_.front(); } + + private: + std::queue queue_; + mutable std::mutex mutex_; + std::condition_variable cond_; +}; + +struct MemoStore { + // Stores last 
known values for all the keys + + struct Entry { + // Timestamp associated with the entry + int64_t time; + + // Batch associated with the entry (perf is probably OK for this; batches change + // rarely) + std::shared_ptr batch; + + // Row associated with the entry + row_index_t row; + }; + + std::unordered_map entries_; + + // Records (batch, row, time) as the latest known value for the key, creating the + // entry if it does not exist yet + void Store(const std::shared_ptr& batch, row_index_t row, int64_t time, + KeyType key) { + auto& e = entries_[key]; + // Only reassign the shared_ptr when the batch actually differs; skipping the + // assignment in the common case is what makes holding a shared_ptr per entry + // acceptable (the batch shouldn't change that often) + if (e.batch != batch) e.batch = batch; + e.row = row; + e.time = time; + } + + // Returns a pointer to the entry stored for the key, or nullopt if there is none + util::optional GetEntryForKey(KeyType key) const { + auto e = entries_.find(key); + if (entries_.end() == e) return util::nullopt; + return util::optional(&e->second); + } + + // Erases every entry whose time is strictly less than ts + void RemoveEntriesWithLesserTime(int64_t ts) { + for (auto e = entries_.begin(); e != entries_.end();) + if (e->second.time < ts) + e = entries_.erase(e); + else + ++e; + } +}; + +class InputState { + // InputState corresponds to an input + // Input record batches are queued up in InputState until processed and + // turned into output record batches. 
+ + public: + InputState(const std::shared_ptr& schema, + const std::string& time_col_name, const std::string& key_col_name) + : queue_(), + schema_(schema), + time_col_index_(schema->GetFieldIndex(time_col_name)), + key_col_index_(schema->GetFieldIndex(key_col_name)) {} + + col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { + src_to_dst_.resize(schema_->num_fields()); + for (int i = 0; i < schema_->num_fields(); ++i) + if (!(skip_time_and_key_fields && IsTimeOrKeyColumn(i))) + src_to_dst_[i] = dst_offset++; + return dst_offset; + } + + const util::optional& MapSrcToDst(col_index_t src) const { + return src_to_dst_[src]; + } + + bool IsTimeOrKeyColumn(col_index_t i) const { + DCHECK_LT(i, schema_->num_fields()); + return (i == time_col_index_) || (i == key_col_index_); + } + + // Gets the latest row index, assuming the queue isn't empty + row_index_t GetLatestRow() const { return latest_ref_row_; } + + bool Empty() const { + // cannot be empty if ref row is >0 -- can avoid slow queue lock + // below + if (latest_ref_row_ > 0) return false; + return queue_.Empty(); + } + + int total_batches() const { return total_batches_; } + + // Gets latest batch (precondition: must not be empty) + const std::shared_ptr& GetLatestBatch() const { + return queue_.UnsyncFront(); + } + + KeyType GetLatestKey() const { + return queue_.UnsyncFront() + ->column_data(key_col_index_) + ->GetValues(1)[latest_ref_row_]; + } + + int64_t GetLatestTime() const { + return queue_.UnsyncFront() + ->column_data(time_col_index_) + ->GetValues(1)[latest_ref_row_]; + } + + bool Finished() const { return batches_processed_ == total_batches_; } + + bool Advance() { + // Try advancing to the next row and update latest_ref_row_ + // Returns true if able to advance, false if not. 
+ bool have_active_batch = + (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); + + if (have_active_batch) { + // If we have an active batch + if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) { + // hit the end of the batch, need to get the next batch if possible. + ++batches_processed_; + latest_ref_row_ = 0; + // NOTE(review): the queue is non-empty at this point, so TryPop() returns a + // value and the negation always clears have_active_batch. Advance() therefore + // reports false at every batch boundary, even when another batch is already + // queued, and the DCHECK_GT below is never reached. Callers appear to rely on + // this to end the current pass at a batch boundary -- confirm before changing. + have_active_batch &= !queue_.TryPop(); + if (have_active_batch) + DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed + } + } + return have_active_batch; + } + + // Advance the data to be immediately past the specified timestamp, updating + // the latest time and latest_ref_row_ to the first row that is past the + // specified timestamp. + // Returns true if updates were made, false if not. + bool AdvanceAndMemoize(int64_t ts) { + // Advance the right side row index until we reach the latest right row (for each key) + // for the given left timestamp. + + // Check if already updated for TS (or if there is no latest) + if (Empty()) return false; // can't advance if empty + auto latest_time = GetLatestTime(); + if (latest_time > ts) return false; // already advanced + + // Not updated. Try to update and possibly advance. + bool updated = false; + do { + latest_time = GetLatestTime(); + // if Advance() returns true, then the latest_ts must also be valid + // Keep advancing right table until we hit the latest row that has + // timestamp <= ts. This is because we only need the latest row for the + // match given a left ts. 
+ if (latest_time <= ts) { + memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); + } else { + break; // hit a future timestamp -- done updating for now + } + updated = true; + } while (Advance()); + return updated; + } + + void Push(const std::shared_ptr& rb) { + if (rb->num_rows() > 0) { + queue_.Push(rb); + } else { + ++batches_processed_; // don't enqueue empty batches, just record as processed + } + } + + util::optional GetMemoEntryForKey(KeyType key) { + return memo_.GetEntryForKey(key); + } + + util::optional GetMemoTimeForKey(KeyType key) { + auto r = GetMemoEntryForKey(key); + if (r.has_value()) { + return (*r)->time; + } else { + return util::nullopt; + } + } + + void RemoveMemoEntriesWithLesserTime(int64_t ts) { + memo_.RemoveEntriesWithLesserTime(ts); + } + + const std::shared_ptr& get_schema() const { return schema_; } + + void set_total_batches(int n) { + DCHECK_GE(n, 0); + DCHECK_EQ(total_batches_, -1) << "Set total batch more than once"; + total_batches_ = n; + } + + private: + // Pending record batches. The latest is the front. Batches cannot be empty. 
+ ConcurrentQueue> queue_; + // Schema associated with the input + std::shared_ptr schema_; + // Total number of batches (only int because InputFinished uses int) + int total_batches_ = -1; + // Number of batches processed so far (only int because InputFinished uses int) + int batches_processed_ = 0; + // Index of the time col + col_index_t time_col_index_; + // Index of the key col + col_index_t key_col_index_; + // Index of the latest row reference within; if >0 then queue_ cannot be empty + // Must be < queue_.front()->num_rows() if queue_ is non-empty + row_index_t latest_ref_row_ = 0; + // Stores latest known values for the various keys + MemoStore memo_; + // Mapping of source columns to destination columns + std::vector> src_to_dst_; +}; + +template +struct CompositeReferenceRow { + struct Entry { + arrow::RecordBatch* batch; // can be NULL if there's no value + row_index_t row; + }; + Entry refs[MAX_TABLES]; +}; + +// A table of composite reference rows. Rows maintain pointers to the +// constituent record batches, but the overall table retains shared_ptr +// references to ensure memory remains resident while the table is live. +// +// The main reason for this is that, especially for wide tables, joins +// are effectively row-oriented, rather than column-oriented. Separating +// the join part from the columnar materialization part simplifies the +// logic around data types and increases efficiency. +// +// We don't put the shared_ptr's into the rows for efficiency reasons. 
+template +class CompositeReferenceTable { + public: + explicit CompositeReferenceTable(size_t n_tables) : n_tables_(n_tables) { + DCHECK_GE(n_tables_, 1); + DCHECK_LE(n_tables_, MAX_TABLES); + } + + size_t n_rows() const { return rows_.size(); } + + // Adds the latest row from the input state as a new composite reference row + // - LHS must have a valid key,timestep,and latest rows + // - RHS must have valid data memo'ed for the key + void Emplace(std::vector>& in, int64_t tolerance) { + DCHECK_EQ(in.size(), n_tables_); + + // Get the LHS key + KeyType key = in[0]->GetLatestKey(); + + // Add row and setup LHS + // (the LHS state comes just from the latest row of the LHS table) + DCHECK(!in[0]->Empty()); + const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); + row_index_t lhs_latest_row = in[0]->GetLatestRow(); + int64_t lhs_latest_time = in[0]->GetLatestTime(); + if (0 == lhs_latest_row) { + // On the first row of the batch, we resize the destination. + // The destination size is dictated by the size of the LHS batch. 
+ row_index_t new_batch_size = lhs_latest_batch->num_rows(); + row_index_t new_capacity = rows_.size() + new_batch_size; + if (rows_.capacity() < new_capacity) rows_.reserve(new_capacity); + } + rows_.resize(rows_.size() + 1); + auto& row = rows_.back(); + row.refs[0].batch = lhs_latest_batch.get(); + row.refs[0].row = lhs_latest_row; + AddRecordBatchRef(lhs_latest_batch); + + // Get the state for that key from all on the RHS -- assumes it's up to date + // (the RHS state comes from the memoized row references) + for (size_t i = 1; i < in.size(); ++i) { + util::optional opt_entry = in[i]->GetMemoEntryForKey(key); + if (opt_entry.has_value()) { + DCHECK(*opt_entry); + if ((*opt_entry)->time + tolerance >= lhs_latest_time) { + // Have a valid entry + const MemoStore::Entry* entry = *opt_entry; + row.refs[i].batch = entry->batch.get(); + row.refs[i].row = entry->row; + AddRecordBatchRef(entry->batch); + continue; + } + } + row.refs[i].batch = NULL; + row.refs[i].row = 0; + } + } + + // Materializes the current reference table into a target record batch + Result> Materialize( + const std::shared_ptr& output_schema, + const std::vector>& state) { + DCHECK_EQ(state.size(), n_tables_); + + // Don't build empty batches + size_t n_rows = rows_.size(); + if (!n_rows) return NULLPTR; + + // Build the arrays column-by-column from the rows + std::vector> arrays(output_schema->num_fields()); + for (size_t i_table = 0; i_table < n_tables_; ++i_table) { + int n_src_cols = state.at(i_table)->get_schema()->num_fields(); + { + for (col_index_t i_src_col = 0; i_src_col < n_src_cols; ++i_src_col) { + util::optional i_dst_col_opt = + state[i_table]->MapSrcToDst(i_src_col); + if (!i_dst_col_opt) continue; + col_index_t i_dst_col = *i_dst_col_opt; + const auto& src_field = state[i_table]->get_schema()->field(i_src_col); + const auto& dst_field = output_schema->field(i_dst_col); + DCHECK(src_field->type()->Equals(dst_field->type())); + DCHECK_EQ(src_field->name(), dst_field->name()); + 
const auto& field_type = src_field->type(); + + if (field_type->Equals(arrow::int32())) { + ARROW_ASSIGN_OR_RAISE( + arrays.at(i_dst_col), + (MaterializePrimitiveColumn(i_table, + i_src_col))); + } else if (field_type->Equals(arrow::int64())) { + ARROW_ASSIGN_OR_RAISE( + arrays.at(i_dst_col), + (MaterializePrimitiveColumn(i_table, + i_src_col))); + } else if (field_type->Equals(arrow::float32())) { + ARROW_ASSIGN_OR_RAISE(arrays.at(i_dst_col), + (MaterializePrimitiveColumn( + i_table, i_src_col))); + } else if (field_type->Equals(arrow::float64())) { + ARROW_ASSIGN_OR_RAISE( + arrays.at(i_dst_col), + (MaterializePrimitiveColumn(i_table, + i_src_col))); + } else { + ARROW_RETURN_NOT_OK( + Status::Invalid("Unsupported data type: ", src_field->name())); + } + } + } + } + + // Build the result + DCHECK_LE(n_rows, (uint64_t)std::numeric_limits::max()); + std::shared_ptr r = + arrow::RecordBatch::Make(output_schema, (int64_t)n_rows, arrays); + return r; + } + + // Returns true if there are no rows + bool empty() const { return rows_.empty(); } + + private: + // Contains shared_ptr refs for all RecordBatches referred to by the contents of rows_ + std::unordered_map> _ptr2ref; + + // Row table references + std::vector> rows_; + + // Total number of tables in the composite table + size_t n_tables_; + + // Adds a RecordBatch ref to the mapping, if needed + void AddRecordBatchRef(const std::shared_ptr& ref) { + if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; + } + + template + Result> MaterializePrimitiveColumn(size_t i_table, + col_index_t i_col) { + Builder builder; + ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size())); + for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { + const auto& ref = rows_[i_row].refs[i_table]; + if (ref.batch) { + builder.UnsafeAppend( + ref.batch->column_data(i_col)->template GetValues(1)[ref.row]); + } else { + builder.UnsafeAppendNull(); + } + } + std::shared_ptr result; + 
ARROW_RETURN_NOT_OK(builder.Finish(&result)); + return result; + } +}; + +class AsofJoinNode : public ExecNode { + // Advances the RHS as far as possible to be up to date for the current LHS timestamp + bool UpdateRhs() { + auto& lhs = *state_.at(0); + auto lhs_latest_time = lhs.GetLatestTime(); + bool any_updated = false; + for (size_t i = 1; i < state_.size(); ++i) + any_updated |= state_[i]->AdvanceAndMemoize(lhs_latest_time); + return any_updated; + } + + // Returns false if RHS not up to date for LHS + bool IsUpToDateWithLhsRow() const { + auto& lhs = *state_[0]; + if (lhs.Empty()) return false; // can't proceed if nothing on the LHS + int64_t lhs_ts = lhs.GetLatestTime(); + for (size_t i = 1; i < state_.size(); ++i) { + auto& rhs = *state_[i]; + if (!rhs.Finished()) { + // If RHS is finished, then we know it's up to date + if (rhs.Empty()) + return false; // RHS isn't finished, but is empty --> not up to date + if (lhs_ts >= rhs.GetLatestTime()) + return false; // RHS isn't up to date (and not finished) + } + } + return true; + } + + Result> ProcessInner() { + DCHECK(!state_.empty()); + auto& lhs = *state_.at(0); + + // Construct new target table if needed + CompositeReferenceTable dst(state_.size()); + + // Generate rows into the dst table until we either run out of data or hit the row + // limit, or run out of input + for (;;) { + // If LHS is finished or empty then there's nothing we can do here + if (lhs.Finished() || lhs.Empty()) break; + + // Advance each of the RHS as far as possible to be up to date for the LHS timestamp + bool any_rhs_advanced = UpdateRhs(); + + // If we have received enough inputs to produce the next output batch + // (decided by IsUpToDateWithLhsRow), we will perform the join and + // materialize the output batch. The join is done by advancing through + // the LHS and adding joined row to rows_ (done by Emplace). Finally, + // input batches that are no longer needed are removed to free up memory. 
+ if (IsUpToDateWithLhsRow()) { + dst.Emplace(state_, options_.tolerance); + if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch + } else { + if (!any_rhs_advanced) break; // need to wait for new data + } + } + + // Prune memo entries that have expired (to bound memory consumption) + if (!lhs.Empty()) { + for (size_t i = 1; i < state_.size(); ++i) { + state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - + options_.tolerance); + } + } + + // Emit the batch + if (dst.empty()) { + return NULLPTR; + } else { + return dst.Materialize(output_schema(), state_); + } + } + + void Process() { + std::lock_guard guard(gate_); + if (finished_.is_finished()) { + return; + } + + // Process batches while we have data + for (;;) { + Result> result = ProcessInner(); + + if (result.ok()) { + auto out_rb = *result; + if (!out_rb) break; + ++batches_produced_; + ExecBatch out_b(*out_rb); + outputs_[0]->InputReceived(this, std::move(out_b)); + } else { + StopProducing(); + ErrorIfNotOk(result.status()); + return; + } + } + + // Report to the output the total batch count, if we've already finished everything + // (there are two places where this can happen: here and InputFinished) + // + // It may happen here in cases where InputFinished was called before we were finished + // producing results (so we didn't know the output size at that time) + if (state_.at(0)->Finished()) { + StopProducing(); + outputs_[0]->InputFinished(this, batches_produced_); + } + } + + void ProcessThread() { + for (;;) { + if (!process_.Pop()) { + return; + } + Process(); + } + } + + static void ProcessThreadWrapper(AsofJoinNode* node) { node->ProcessThread(); } + + public: + AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, + const AsofJoinNodeOptions& join_options, + std::shared_ptr output_schema); + + virtual ~AsofJoinNode() { + process_.Push(false); // poison pill + process_thread_.join(); + } + + static arrow::Result> MakeOutputSchema( + const 
std::vector& inputs, const AsofJoinNodeOptions& options) { + std::vector> fields; + + // on_key and by_key must be plain field-name references: FieldRef::name() yields a + // null pointer otherwise, so validate before dereferencing + const auto* on_field_name_ptr = options.on_key.name(); + const auto* by_field_name_ptr = options.by_key.name(); + if (on_field_name_ptr == NULLPTR || by_field_name_ptr == NULLPTR) { + return Status::Invalid("on_key and by_key must be field name references"); + } + const auto& on_field_name = *on_field_name_ptr; + const auto& by_field_name = *by_field_name_ptr; + + // Keep every field of the left table (j == 0); from each right table keep only + // the fields that are neither the on key nor the by key + for (size_t j = 0; j < inputs.size(); ++j) { + const auto& input_schema = inputs[j]->output_schema(); + const auto& on_field_ix = input_schema->GetFieldIndex(on_field_name); + const auto& by_field_ix = input_schema->GetFieldIndex(by_field_name); + + // An index of -1 means the name could not be resolved in this input + if ((on_field_ix == -1) || (by_field_ix == -1)) { + return Status::Invalid("Missing join key on table ", j); + } + + for (int i = 0; i < input_schema->num_fields(); ++i) { + const auto& field = input_schema->field(i); + if (field->name() == on_field_name) { + if (kSupportedOnTypes_.find(field->type()) == kSupportedOnTypes_.end()) { + return Status::Invalid("Unsupported type for on key: ", field->name()); + } + // Only add on field from the left table + if (j == 0) { + fields.push_back(field); + } + } else if (field->name() == by_field_name) { + if (kSupportedByTypes_.find(field->type()) == kSupportedByTypes_.end()) { + return Status::Invalid("Unsupported type for by key: ", field->name()); + } + // Only add by field from the left table + if (j == 0) { + fields.push_back(field); + } + } else { + if (kSupportedDataTypes_.find(field->type()) == kSupportedDataTypes_.end()) { + return Status::Invalid("Unsupported data type: ", field->name()); + } + + fields.push_back(field); + } + } + } + return std::make_shared(fields); + } + + // Node factory: checks the number of inputs, derives the output schema and + // emplaces the node into the plan. Returns Invalid on any validation failure. + static arrow::Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + // Enforce the input-count contract in release builds as well: fewer than two + // inputs is not a join, and more than MAX_JOIN_TABLES would overrun the + // fixed-size per-row reference array + if (inputs.size() < 2) { + return Status::Invalid("Must have at least two inputs"); + } + if (inputs.size() > MAX_JOIN_TABLES) { + return Status::Invalid("Must have at most ", MAX_JOIN_TABLES, " inputs"); + } + + const auto& join_options = checked_cast(options); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, + MakeOutputSchema(inputs, join_options)); + + std::vector input_labels(inputs.size()); + input_labels[0] = "left"; + for (size_t i = 1; i < inputs.size(); ++i) { + input_labels[i] = "right_" + 
std::to_string(i); + } + + return plan->EmplaceNode(plan, inputs, std::move(input_labels), + join_options, std::move(output_schema)); + } + + const char* kind_name() const override { return "AsofJoinNode"; } + + void InputReceived(ExecNode* input, ExecBatch batch) override { + // Get the input + ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); + size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + + // Put into the queue + auto rb = *batch.ToRecordBatch(input->output_schema()); + state_.at(k)->Push(rb); + process_.Push(true); + } + void ErrorReceived(ExecNode* input, Status error) override { + outputs_[0]->ErrorReceived(this, std::move(error)); + StopProducing(); + } + void InputFinished(ExecNode* input, int total_batches) override { + { + std::lock_guard guard(gate_); + ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); + size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + state_.at(k)->set_total_batches(total_batches); + } + // Trigger a process call + // The reason for this is that there are cases at the end of a table where we don't + // know whether the RHS of the join is up-to-date until we know that the table is + // finished. 
+ process_.Push(true); + } + Status StartProducing() override { + finished_ = arrow::Future<>::Make(); + return Status::OK(); + } + void PauseProducing(ExecNode* output, int32_t counter) override {} + void ResumeProducing(ExecNode* output, int32_t counter) override {} + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + StopProducing(); + } + void StopProducing() override { finished_.MarkFinished(); } + arrow::Future<> finished() override { return finished_; } + + private: + static const std::set> kSupportedOnTypes_; + static const std::set> kSupportedByTypes_; + static const std::set> kSupportedDataTypes_; + + arrow::Future<> finished_; + // InputStates + // Each input state corresponds to an input table + std::vector> state_; + // Held by Process() for its whole run and by InputFinished() while it records + // the total batch count + std::mutex gate_; + AsofJoinNodeOptions options_; + + // Queue for triggering processing of a given input + // (a false value is a poison pill) + ConcurrentQueue process_; + // Worker thread + std::thread process_thread_; + + // In-progress batches produced + int batches_produced_ = 0; +}; + +// Builds one InputState per input and assigns output column positions: the first +// input keeps all of its columns, later inputs skip their time and key columns. +// NOTE(review): process_thread_ is started from the initializer list, before state_ +// is populated in the body; it blocks in process_.Pop() until the first Push -- +// confirm nothing can Push before construction completes. +AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, + std::vector input_labels, + const AsofJoinNodeOptions& join_options, + std::shared_ptr output_schema) + : ExecNode(plan, inputs, input_labels, + /*output_schema=*/std::move(output_schema), + /*num_outputs=*/1), + options_(join_options), + process_(), + process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { + for (size_t i = 0; i < inputs.size(); ++i) + state_.push_back(::arrow::internal::make_unique( + inputs[i]->output_schema(), *options_.on_key.name(), *options_.by_key.name())); + col_index_t dst_offset = 0; + for (auto& state : state_) + dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + + finished_ = arrow::Future<>::MakeFinished(); +} + +// Currently supported types +const std::set> AsofJoinNode::kSupportedOnTypes_ = {int64()}; +const std::set> AsofJoinNode::kSupportedByTypes_ = {int32()}; +const std::set> 
AsofJoinNode::kSupportedDataTypes_ = { + int32(), int64(), float32(), float64()}; + +namespace internal { +void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make)); +} +} // namespace internal + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc new file mode 100644 index 00000000000..8b993764abe --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include +#include +#include + +#include "arrow/api.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/test_util.h" +#include "arrow/compute/exec/util.h" +#include "arrow/compute/kernels/row_encoder.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" +#include "arrow/testing/random.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" + +using testing::UnorderedElementsAreArray; + +namespace arrow { +namespace compute { + +void CheckRunOutput(const BatchesWithSchema& l_batches, + const BatchesWithSchema& r0_batches, + const BatchesWithSchema& r1_batches, + const BatchesWithSchema& exp_batches, const FieldRef time, + const FieldRef keys, const int64_t tolerance) { + auto exec_ctx = + arrow::internal::make_unique(default_memory_pool(), nullptr); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + + AsofJoinNodeOptions join_options(time, keys, tolerance); + Declaration join{"asofjoin", join_options}; + + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r0_batches.schema, r0_batches.gen(false, false)}}); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r1_batches.schema, r1_batches.gen(false, false)}}); + + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + + ASSERT_OK_AND_ASSIGN(auto exp_table, + TableFromExecBatches(exp_batches.schema, exp_batches.batches)); + + ASSERT_OK_AND_ASSIGN(auto res_table, TableFromExecBatches(exp_batches.schema, res)); + + AssertTablesEqual(*exp_table, *res_table, + /*same_chunk_layout=*/true, /*flatten=*/true); +} + +void 
DoRunBasicTest(const std::vector& l_data, + const std::vector& r0_data, + const std::vector& r1_data, + const std::vector& exp_data, int64_t tolerance) { + auto l_schema = + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); + auto r0_schema = + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())}); + auto r1_schema = + schema({field("time", int64()), field("key", int32()), field("r1_v0", float32())}); + + auto exp_schema = schema({ + field("time", int64()), + field("key", int32()), + field("l_v0", float64()), + field("r0_v0", float64()), + field("r1_v0", float32()), + }); + + // Test three table join + BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; + l_batches = MakeBatchesFromString(l_schema, l_data); + r0_batches = MakeBatchesFromString(r0_schema, r0_data); + r1_batches = MakeBatchesFromString(r1_schema, r1_data); + exp_batches = MakeBatchesFromString(exp_schema, exp_data); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", + tolerance); +} + +void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); + BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); + + ExecContext exec_ctx; + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); + + AsofJoinNodeOptions join_options("time", "key", 0); + Declaration join{"asofjoin", join_options}; + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); + + ASSERT_RAISES(Invalid, join.AddToPlan(plan.get())); +} + +class AsofJoinTest : public testing::Test {}; + +TEST(AsofJoinTest, TestBasic1) { + // Single key, single batch + DoRunBasicTest( + /*l*/ {R"([[0, 1, 1.0], [1000, 1, 2.0]])"}, + /*r0*/ 
{R"([[0, 1, 11.0]])"}, + /*r1*/ {R"([[1000, 1, 101.0]])"}, + /*exp*/ {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"}, 1000); +} + +TEST(AsofJoinTest, TestBasic2) { + // Single key, multiple batches + DoRunBasicTest( + /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, + /*r0*/ {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"}, + /*r1*/ {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"}, + /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); +} + +TEST(AsofJoinTest, TestBasic3) { + // Single key, multiple left batches, single right batches + DoRunBasicTest( + /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, + /*r0*/ {R"([[0, 1, 11.0], [1000, 1, 12.0]])"}, + /*r1*/ {R"([[0, 1, 101.0], [1000, 1, 102.0]])"}, + /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); +} + +TEST(AsofJoinTest, TestBasic4) { + // Multi key, multiple batches, misaligned batches + DoRunBasicTest( + /*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + 1000); +} + +TEST(AsofJoinTest, TestBasic5) { + // Multi key, multiple batches, misaligned batches, smaller tolerance + DoRunBasicTest(/*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 
2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + 500); +} + +TEST(AsofJoinTest, TestBasic6) { + // Multi key, multiple batches, misaligned batches, zero tolerance + DoRunBasicTest(/*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null, null], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null, null]])"}, + 0); +} + +TEST(AsofJoinTest, TestEmpty1) { + // Empty left batch + DoRunBasicTest(/*l*/ + {R"([])", R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + 1000); +} + +TEST(AsofJoinTest, TestEmpty2) { + // Empty left input + DoRunBasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + 
/*exp*/ + {R"([])"}, 1000); +} + +TEST(AsofJoinTest, TestEmpty3) { + // Empty right batch + DoRunBasicTest(/*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([])", R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + 1000); +} + +TEST(AsofJoinTest, TestEmpty4) { + // Empty right input + DoRunBasicTest(/*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, null, 1002.0]])", + R"([[2000, 1, 4.0, null, 103.0], [2000, 2, 24.0, null, 1002.0]])"}, + 1000); +} + +TEST(AsofJoinTest, TestEmpty5) { + // All empty + DoRunBasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); +} + +TEST(AsofJoinTest, TestUnsupportedOntype) { + DoRunInvalidTypeTest( + schema({field("time", utf8()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", utf8()), field("key", int32()), field("r0_v0", float32())})); +} + +TEST(AsofJoinTest, TestUnsupportedBytype) { + DoRunInvalidTypeTest( + schema({field("time", int64()), field("key", utf8()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", utf8()), 
field("r0_v0", float32())})); +} + +TEST(AsofJoinTest, TestUnsupportedDatatype) { + // Utf8 is unsupported + DoRunInvalidTypeTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", utf8())})); +} + +TEST(AsofJoinTest, TestMissingKeys) { + DoRunInvalidTypeTest( + schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}), + schema( + {field("time1", int64()), field("key", int32()), field("r0_v0", float64())})); + + DoRunInvalidTypeTest( + schema({field("time", int64()), field("key1", int32()), field("l_v0", float64())}), + schema( + {field("time", int64()), field("key1", int32()), field("r0_v0", float64())})); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index b7a9c7e1bb0..95e8953065e 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -531,6 +531,7 @@ void RegisterUnionNode(ExecFactoryRegistry*); void RegisterAggregateNode(ExecFactoryRegistry*); void RegisterSinkNode(ExecFactoryRegistry*); void RegisterHashJoinNode(ExecFactoryRegistry*); +void RegisterAsofJoinNode(ExecFactoryRegistry*); } // namespace internal @@ -545,6 +546,7 @@ ExecFactoryRegistry* default_exec_factory_registry() { internal::RegisterAggregateNode(this); internal::RegisterSinkNode(this); internal::RegisterHashJoinNode(this); + internal::RegisterAsofJoinNode(this); } Result GetFactory(const std::string& factory_name) override { diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 3c901be0d2e..355b9083b03 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -361,6 +361,35 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { Expression filter; }; +/// \brief Make a node which implements asof join operation +/// +/// Note, 
this API is experimental and will change in the future +/// +/// This node takes one left table and any number of right tables, and asof joins them +/// together. Batches produced by each input must be ordered by the "on" key. +/// This node will output one row for each row in the left table. +class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { + public: + AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance) + : on_key(std::move(on_key)), by_key(std::move(by_key)), tolerance(tolerance) {} + + /// \brief "on" key for the join. + /// + /// All input tables must be sorted by the "on" key. Inexact + /// match is used on the "on" key, i.e., a row is considered a match iff + /// left_on - tolerance <= right_on <= left_on. + /// Currently, the "on" key must be an int64 field + FieldRef on_key; + /// \brief "by" key for the join. + /// + /// All input tables must have the "by" key. Exact equality + /// is used for the "by" key. + /// Currently, the "by" key must be an int32 field + FieldRef by_key; + /// Tolerance for inexact "on" key matching + int64_t tolerance; +}; + /// \brief Make a node which select top_k/bottom_k rows passed through it /// /// All batches pushed to this node will be accumulated, then selected, by the given diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 41eb401ced6..2eabc1d1c26 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -221,6 +221,30 @@ BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, return out; } +BatchesWithSchema MakeBatchesFromString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity) { + BatchesWithSchema out_batches{{}, schema}; + + std::vector descrs; + for (auto&& field : schema->fields()) { + descrs.emplace_back(field->type()); + } + + for (auto&& s : json_strings) { + out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + } + +
size_t batch_count = out_batches.batches.size(); + for (int repeat = 1; repeat < multiplicity; ++repeat) { + for (size_t i = 0; i < batch_count; ++i) { + out_batches.batches.push_back(out_batches.batches[i]); + } + } + + return out_batches; +} + Result> SortTableOnAllFields(const std::shared_ptr& tab) { std::vector sort_keys; for (auto&& f : tab->schema()->fields()) { diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 9347d1343f1..c69f1b94446 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -96,6 +96,11 @@ ARROW_TESTING_EXPORT BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, int num_batches = 10, int batch_size = 4); +ARROW_TESTING_EXPORT +BatchesWithSchema MakeBatchesFromString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity = 1); + ARROW_TESTING_EXPORT Result> SortTableOnAllFields(const std::shared_ptr
& tab);