From 23b8c71ad96a6aea3b5be5dd01500291d3112309 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 18 Apr 2022 12:01:19 -0400 Subject: [PATCH 01/47] wip --- cpp/src/arrow/compute/exec/asof_join.cc | 33 +++++ cpp/src/arrow/compute/exec/asof_join.h | 50 ++++++++ cpp/src/arrow/compute/exec/asof_join_node.cc | 115 ++++++++++++++++++ .../arrow/compute/exec/asof_join_node_test.cc | 84 +++++++++++++ 4 files changed, 282 insertions(+) create mode 100644 cpp/src/arrow/compute/exec/asof_join.cc create mode 100644 cpp/src/arrow/compute/exec/asof_join.h create mode 100644 cpp/src/arrow/compute/exec/asof_join_node.cc create mode 100644 cpp/src/arrow/compute/exec/asof_join_node_test.cc diff --git a/cpp/src/arrow/compute/exec/asof_join.cc b/cpp/src/arrow/compute/exec/asof_join.cc new file mode 100644 index 00000000000..9413cccf4b6 --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join.cc @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/compute/exec/asof_join.h" + +namespace arrow { +namespace compute { + +class AsofJoinBasicImpl : public AsofJoinImpl { + +}; + +Result> AsofJoinImpl::MakeBasic() { + std::unique_ptr impl{new AsofJoinBasicImpl()}; + return std::move(impl); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h new file mode 100644 index 00000000000..bace2419f2e --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include +#include + +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/task_util.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/tracing_internal.h" + +namespace arrow { +namespace compute { + +class AsofJoinSchema { + public: + std::shared_ptr MakeOutputSchema(const std::vector& inputs, + const AsofJoinNodeOptions& options + ); + +}; + +class AsofJoinImpl { + public: + static Result> MakeBasic(); + +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc new file mode 100644 index 00000000000..86122c0c839 --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/asof_join.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { +namespace compute { + +class AsofJoinNode : public ExecNode { + public: + AsofJoinNode(ExecPlan* plan, + NodeVector inputs, + const AsofJoinNodeOptions& join_options, + std::shared_ptr output_schema, + std::unique_ptr schema_mgr, + std::unique_ptr impl + ) + : ExecNode(plan, inputs, {"left", "right"}, + /*output_schema=*/std::move(output_schema), + /*num_outputs=*/1), + impl_(std::move(impl)) { + complete_.store(false); + } + + static arrow::Result Make(ExecPlan *plan, std::vector inputs, + const ExecNodeOptions &options) { + std::unique_ptr schema_mgr = + ::arrow::internal::make_unique(); + + const auto& join_options = checked_cast(options); + std::shared_ptr output_schema = schema_mgr->MakeOutputSchema(inputs, join_options); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, AsofJoinImpl::MakeBasic()); + + return plan->EmplaceNode( + plan, inputs, join_options, std::move(output_schema), std::move(schema_mgr), + std::move(impl) + ); + } + + const char* kind_name() const override { return "AsofJoinNode"; } + + void InputReceived(ExecNode* input, ExecBatch batch) override {} + void ErrorReceived(ExecNode* input, Status error) override {} + void InputFinished(ExecNode* input, int total_batches) override {} + Status StartProducing() override { return Status::OK();} + void PauseProducing(ExecNode* output) override {} + void ResumeProducing(ExecNode* output) override {} + void StopProducing(ExecNode* output) override {} + void StopProducing() override {} + Future<> finished() override { return finished_; } + + private: + std::atomic complete_; + std::unique_ptr schema_mgr_; + std::unique_ptr 
impl_; +}; + +std::shared_ptr AsofJoinSchema::MakeOutputSchema(const std::vector& inputs, + const AsofJoinNodeOptions& options) { + std::vector> fields; + assert(inputs.size() > 1); + + std::vector keys(options.keys.size()); + + // Directly map LHS fields + for(int i = 0; i < inputs[0]->output_schema()->num_fields(); ++i) + fields.push_back(inputs[0]->output_schema()->field(i)); + + // Take all non-key, non-time RHS fields + for(size_t j = 1; j < inputs.size(); ++j) { + const auto &input_schema = inputs[j]->output_schema(); + for(int i = 0; i < input_schema->num_fields(); ++i) { + const auto &name = input_schema->field(i)->name(); + if((std::find(keys.begin(), keys.end(), name) != keys.end()) && (name!= *options.time.name())) { + fields.push_back(input_schema->field(i)); + } + } + } + + // Combine into a schema + return std::make_shared(fields); +} + +namespace internal { + void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make)); + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc new file mode 100644 index 00000000000..993cd3882f6 --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include + +#include "arrow/api.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/test_util.h" +#include "arrow/compute/exec/util.h" +#include "arrow/compute/kernels/row_encoder.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" +#include "arrow/testing/random.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" + +using testing::UnorderedElementsAreArray; + +namespace arrow { +namespace compute { + +BatchesWithSchema GenerateBatchesFromString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity = 1) { + BatchesWithSchema out_batches{{}, schema}; + + std::vector descrs; + for (auto&& field : schema->fields()) { + descrs.emplace_back(field->type()); + } + + for (auto&& s : json_strings) { + out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + } + + size_t batch_count = out_batches.batches.size(); + for (int repeat = 1; repeat < multiplicity; ++repeat) { + for (size_t i = 0; i < batch_count; ++i) { + out_batches.batches.push_back(out_batches.batches[i]); + } + } + + return out_batches; +} + +void RunNonEmptyTest() { + auto l_schema = schema( + { + field("time", timestamp64(TimeUnit::NANO)), + field("key", int32()), + field("l_v0", float32()) + } + ); + auto r_schema = schema( + { + field("time", timestamp64(TimeUnit::NANO)), + field("key", int32()), + field("r_v0", float32()) + } + ); +} + +} // namespace compute +} 
// namespace arrow From f4b21067941c7d5c60983b9266dd550ee6e58633 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 18 Apr 2022 16:19:54 -0400 Subject: [PATCH 02/47] wip --- cpp/src/arrow/CMakeLists.txt | 2 + cpp/src/arrow/compute/exec/CMakeLists.txt | 5 ++ .../arrow/compute/exec/asof_join_node_test.cc | 58 ++++++++++++++++++- cpp/src/arrow/compute/exec/exec_plan.cc | 1 + cpp/src/arrow/compute/exec/options.h | 12 ++++ 5 files changed, 75 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 3b49f4ca00a..0a098df410f 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -381,6 +381,8 @@ if(ARROW_COMPUTE) compute/cast.cc compute/exec.cc compute/exec/aggregate_node.cc + compute/exec/asof_join.cc + compute/exec/asof_join_node.cc compute/exec/bloom_filter.cc compute/exec/exec_plan.cc compute/exec/expression.cc diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt index b2a21c2bd6b..f19921f4f28 100644 --- a/cpp/src/arrow/compute/exec/CMakeLists.txt +++ b/cpp/src/arrow/compute/exec/CMakeLists.txt @@ -32,6 +32,11 @@ add_arrow_compute_test(hash_join_node_test hash_join_node_test.cc bloom_filter_test.cc key_hash_test.cc) +add_arrow_compute_test(asof_join_node_test + PREFIX + "arrow-compute" + SOURCES + asof_join_node_test.cc) add_arrow_compute_test(tpch_node_test PREFIX "arrow-compute") add_arrow_compute_test(union_node_test PREFIX "arrow-compute") add_arrow_compute_test(util_test PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 993cd3882f6..0c8d27083b0 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -63,21 +63,73 @@ BatchesWithSchema GenerateBatchesFromString( return out_batches; } -void RunNonEmptyTest() { +void CheckRunOutput(const BatchesWithSchema& l_batches, + const 
BatchesWithSchema& r_batches, + const FieldRef time, + const std::vector& keys) { + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), nullptr + ); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + + AsofJoinNodeOptions join_options{time, keys}; + Declaration join{"asofjoin", join_options}; + + join.inputs.emplace_back( + Declaration{"source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); + join.inputs.emplace_back( + Declaration{"source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); + + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + // ASSERT_FNISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); +} + +void RunNonEmptyTest(bool exact_matches) { auto l_schema = schema( { - field("time", timestamp64(TimeUnit::NANO)), + field("time", timestamp(TimeUnit::NANO)), field("key", int32()), field("l_v0", float32()) } ); auto r_schema = schema( { - field("time", timestamp64(TimeUnit::NANO)), + field("time", timestamp(TimeUnit::NANO)), field("key", int32()), field("r_v0", float32()) } ); + + BatchesWithSchema l_batches, r_batches, exp_batches; + + l_batches = GenerateBatchesFromString( + l_schema, + {R"([["2020-01-01", 0, 1.0]])"} + + ); + r_batches = GenerateBatchesFromString( + l_schema, + {R"([["2020-01-01", 0, 2.0]])"} + + ); + + CheckRunOutput(l_batches, r_batches, "time", /*keys=*/{{"key"}}); +} + + class AsofJoinTest : public testing::TestWithParam> {}; + +INSTANTIATE_TEST_SUITE_P( + AsofJoinTest, AsofJoinTest, + ::testing::Combine( + ::testing::Values(false, true) + )); + +TEST_P(AsofJoinTest, TestExactMatches) { + RunNonEmptyTest(std::get<0>(GetParam())); } } // namespace compute diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index b7a9c7e1bb0..7c7ae930308 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ 
b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -531,6 +531,7 @@ void RegisterUnionNode(ExecFactoryRegistry*); void RegisterAggregateNode(ExecFactoryRegistry*); void RegisterSinkNode(ExecFactoryRegistry*); void RegisterHashJoinNode(ExecFactoryRegistry*); +void RegisterAsofJoinNode(ExecFactoryRegistry*); } // namespace internal diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 3c901be0d2e..26933e2d757 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -361,6 +361,18 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { Expression filter; }; + +class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { + public: + AsofJoinNodeOptions(FieldRef time, std::vector keys) : + time(std::move(time)), keys(std::move(keys)) {} + + // time column + FieldRef time; + // keys used for the join. All tables must have the same join key. + std::vector keys; +}; + /// \brief Make a node which select top_k/bottom_k rows passed through it /// /// All batches pushed to this node will be accumulated, then selected, by the given From 7ab446d02a436fde37dba60e243d5067cf9ff058 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 21 Apr 2022 11:36:03 -0400 Subject: [PATCH 03/47] wip --- cpp/src/arrow/compute/exec/asof_join.cc | 31 + cpp/src/arrow/compute/exec/asof_join.h | 26 + cpp/src/arrow/compute/exec/asof_join_node.cc | 648 +++++++++++++++++- .../arrow/compute/exec/asof_join_node_test.cc | 14 +- .../compute/exec/concurrent_bounded_queue.h | 72 ++ cpp/src/arrow/compute/exec/exec_plan.cc | 1 + cpp/src/arrow/compute/exec/options.h | 8 +- 7 files changed, 775 insertions(+), 25 deletions(-) create mode 100644 cpp/src/arrow/compute/exec/concurrent_bounded_queue.h diff --git a/cpp/src/arrow/compute/exec/asof_join.cc b/cpp/src/arrow/compute/exec/asof_join.cc index 9413cccf4b6..d70bce213be 100644 --- a/cpp/src/arrow/compute/exec/asof_join.cc +++ b/cpp/src/arrow/compute/exec/asof_join.cc @@ -16,10 
+16,41 @@ // under the License. #include "arrow/compute/exec/asof_join.h" +#include + +#include +#include +#include +#include +#include +//#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // so we don't need to require C++20 +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "concurrent_bounded_queue.h" namespace arrow { namespace compute { + + + class AsofJoinBasicImpl : public AsofJoinImpl { }; diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h index bace2419f2e..b23c8aef1ec 100644 --- a/cpp/src/arrow/compute/exec/asof_join.h +++ b/cpp/src/arrow/compute/exec/asof_join.h @@ -21,6 +21,8 @@ #include #include +#include +#include #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/task_util.h" @@ -28,10 +30,34 @@ #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/tracing_internal.h" +#include +#include +#include +#include // so we don't need to require C++20 + +#include "concurrent_bounded_queue.h" namespace arrow { namespace compute { +typedef int32_t KeyType; + +// Maximum number of tables that can be joined +#define MAX_JOIN_TABLES 6 + +// Capacity of the input queues (for flow control) +// Why 2? +// It needs to be at least 1 to enable progress (otherwise queues have no capacity) +// It needs to be at least 2 to enable addition of a new queue entry while processing +// is being done for another input. +// There's no clear performance benefit to greater than 2. 
+#define QUEUE_CAPACITY 2 + +// The max rows per batch is dictated by the data type for row index +#define MAX_ROWS_PER_BATCH 0xFFFFFFFF +typedef uint32_t row_index_t; +typedef int col_index_t; + class AsofJoinSchema { public: std::shared_ptr MakeOutputSchema(const std::vector& inputs, diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 86122c0c839..321e72647e4 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#include #include #include "arrow/compute/exec/exec_plan.h" @@ -27,10 +28,542 @@ #include "arrow/util/make_unique.h" #include "arrow/util/thread_pool.h" +#include +#include +#include +#include +#include +//#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // so we don't need to require C++20 +#include +#include +#include +#include +#include +#include +#include + +#include + + +#include "concurrent_bounded_queue.h" + namespace arrow { namespace compute { +struct MemoStore { + struct Entry { + // Timestamp associated with the entry + int64_t _time; + + // Batch associated with the entry (perf is probably OK for this; batches change rarely) + std::shared_ptr _batch; + + // Row associated with the entry + row_index_t _row; + }; + + std::unordered_map _entries; + + void store(const std::shared_ptr &batch, row_index_t row, int64_t time, KeyType key) { + auto &e=_entries[key]; + // that we can do this assignment optionally, is why we + // can get array with using shared_ptr above (the batch + // shouldn't change that often) + if(e._batch!=batch) e._batch=batch; + e._row=row; + e._time=time; + } + + util::optional get_entry_for_key(KeyType key) const { + auto e=_entries.find(key); + if(_entries.end()==e) + return util::nullopt; + return util::optional(&e->second); + } + + void 
remove_entries_with_lesser_time(int64_t ts) { + size_t dbg_size0=_entries.size(); + for(auto e=_entries.begin();e!=_entries.end();) + if(e->second._time > _queue; + + // Wildcard key for this input, if applicable. + util::optional _wildcard_key; + + // Schema associated with the input + std::shared_ptr _schema; + + // Total number of batches (only int because InputFinished uses int) + int _total_batches=-1; + + // Number of batches processed so far (only int because InputFinished uses int) + int _batches_processed=0; + + // Index of the time col + col_index_t _time_col_index; + + // Index of the key col + col_index_t _key_col_index; + + // Index of the latest row reference within; if >0 then _queue cannot be empty + row_index_t _latest_ref_row=0; // must be < _queue.front()->num_rows() if _queue is non-empty + + // Stores latest known values for the various keys + MemoStore _memo; + + // Mapping of source columns to destination columns + std::vector> src_to_dst; + +public: + InputState(const std::shared_ptr &schema, + const std::string &time_col_name, + const std::string &key_col_name, + util::optional wildcard_key) + : + _queue(QUEUE_CAPACITY), + _wildcard_key(wildcard_key), + _schema(schema), + _time_col_index(schema->GetFieldIndex(time_col_name)), //TODO: handle missing field name + _key_col_index(schema->GetFieldIndex(key_col_name)) //TODO: handle missing field name + { /*nothing else*/ } + + size_t init_src_to_dst_mapping(size_t dst_offset,bool skip_time_and_key_fields) { + src_to_dst.resize(_schema->num_fields()); + for(int i=0;i<_schema->num_fields();++i) + if(!(skip_time_and_key_fields && is_time_or_key_column(i))) + src_to_dst[i]=dst_offset++; + return dst_offset; + } + + const util::optional& map_src_to_dst(col_index_t src)const { + return src_to_dst[src]; + } + + bool is_time_or_key_column(col_index_t i) const { + assert(i<_schema->num_fields()); + return (i==_time_col_index) || (i==_key_col_index); + } + + // Gets the latest row index, assuming the 
queue isn't empty + row_index_t get_latest_row() const { + return _latest_ref_row; + } + + bool empty() const { + if(_latest_ref_row>0) return false; // cannot be empty if ref row is >0 -- can avoid slow queue lock below + return _queue.empty(); + } + + int count_batches_processed()const { return _batches_processed; } + int count_total_batches()const { return _total_batches; } + + // Gets latest batch (precondition: must not be empty) + const std::shared_ptr &get_latest_batch()const { + return _queue.unsync_front(); + } + KeyType get_latest_key() const { + return _queue.unsync_front()->column_data(_key_col_index)->GetValues(1)[_latest_ref_row]; + } + int64_t get_latest_time()const { + return _queue.unsync_front()->column_data(_time_col_index)->GetValues(1)[_latest_ref_row]; + } + + bool finished() const { + return _batches_processed==_total_batches; + } + + bool advance() { + // Returns true if able to advance, false if not. + bool have_active_batch=(_latest_ref_row>0/*short circuit the lock on the queue*/) || !_queue.empty(); + if(have_active_batch) { + // If we have an active batch + if(++_latest_ref_row>=_queue.unsync_front()->num_rows()) { + // hit the end of the batch, need to get the next batch if possible. + ++_batches_processed; + _latest_ref_row=0; + have_active_batch&=!_queue.try_pop(); + if(have_active_batch) assert(_queue.unsync_front()->num_rows()>0); // empty batches disallowed + } + } + return have_active_batch; + } + + // Advance the data to be immediately past the specified TS, updating latest and latest_ref_row to + // the latest data prior to that immediate just past + // Returns true if updates were made, false if not. + bool advance_and_memoize(int64_t ts) { + // Check if already updated for TS (or if there is no latest) + if(empty()) return false; // can't advance if empty + auto latest_time=get_latest_time(); + if(latest_time>ts) return false; // already advanced + + // Not updated. Try to update and possibly advance. 
+ bool updated=false; + do { + latest_time=get_latest_time(); + if(latest_time<=ts) // if advance() returns true, then the latest_ts must also be valid + _memo.store(get_latest_batch(),_latest_ref_row,latest_time,get_latest_key()); + else + break; // hit a future timestamp -- done updating for now + updated=true; + } while (advance()); + return updated; + } + + void push(const std::shared_ptr &rb) { + if(rb->num_rows()>0) { + _queue.push(rb); + } else { + ++_batches_processed; // don't enqueue empty batches, just record as processed + } + } + + util::optional get_memo_entry_for_key(KeyType key) { + auto r=_memo.get_entry_for_key(key); + if(r.has_value()) return r; + if(_wildcard_key.has_value()) + r=_memo.get_entry_for_key(*_wildcard_key); + return r; + } + + util::optional get_memo_time_for_key(KeyType key) { + auto r=get_memo_entry_for_key(key); + return r.has_value()?util::make_optional((*r)->_time):util::nullopt; + } + + void remove_memo_entries_with_lesser_time(int64_t ts) { + _memo.remove_entries_with_lesser_time(ts); + } + + const std::shared_ptr &get_schema() const { return _schema; } + + void set_total_batches(int n) { + assert(n>=0); // not sure why arrow uses a signed int for this, but it should be >=0 + assert(_total_batches==-1); // shouldn't be set more than once + _total_batches=n; + } +}; + +template +struct CompositeReferenceRow +{ + struct Entry + { + arrow::RecordBatch *_batch; // can be NULL if there's no value + row_index_t _row; + }; + Entry _refs[MAX_TABLES]; +}; + +// A table of composite reference rows. Rows maintain pointers to the +// constituent record batches, but the overall table retains shared_ptr +// references to ensure memory remains resident while the table is live. +// +// The main reason for this is that, especially for wide tables, joins +// are effectively row-oriented, rather than column-oriented. 
Separating +// the join part from the columnar materialization part simplifies the +// logic around data types and increases efficiency. +// +// We don't put the shared_ptr's into the rows for efficiency reasons. +template +class CompositeReferenceTable { + // Contains shared_ptr refs for all RecordBatches referred to by the contents of _rows + std::unordered_map> _ptr2ref; + + // Row table references + std::vector> _rows; + + // Total number of tables in the composite table + size_t _n_tables; + + // Adds a RecordBatch ref to the mapping, if needed + void add_record_batch_ref(const std::shared_ptr &ref) { + if(!_ptr2ref.count((uintptr_t)ref.get())) + _ptr2ref[(uintptr_t)ref.get()]=ref; + } +public: + CompositeReferenceTable(size_t n_tables) : _n_tables(n_tables) + { + assert(_n_tables>=1); + assert(_n_tables<=MAX_TABLES); + } + + size_t n_rows() const { return _rows.size(); } + + // Adds the latest row from the input state as a new composite reference row + // - LHS must have a valid key,timestep,and latest rows + // - RHS must have valid data memo'ed for the key + void emplace(std::vector> &in,int64_t tolerance) { + assert(in.size()==_n_tables); + + // Get the LHS key + KeyType key=in[0]->get_latest_key(); + + // Add row and setup LHS + // (the LHS state comes just from the latest row of the LHS table) + assert(!in[0]->empty()); + const std::shared_ptr &lhs_latest_batch=in[0]->get_latest_batch(); + row_index_t lhs_latest_row=in[0]->get_latest_row(); + int64_t lhs_latest_time=in[0]->get_latest_time(); + if(0==lhs_latest_row) { + // On the first row of the batch, we resize the destination. + // The destination size is dictated by the size of the LHS batch. 
+ assert(lhs_latest_batch->num_rows()<=MAX_ROWS_PER_BATCH); //TODO: better error handling + row_index_t new_batch_size=lhs_latest_batch->num_rows(); + row_index_t new_capacity=_rows.size()+new_batch_size; + if(_rows.capacity() materialize(const std::shared_ptr &output_schema, + const std::vector> &state) { + //cerr << "materialize BEGIN\n"; + assert(state.size()==_n_tables); + assert(state.size()>=1); + + // Don't build empty batches + size_t n_rows=_rows.size(); + if(!n_rows) return nullptr; + + // Count output columns (dbg sanitycheck) + { + int n_out_cols=0; + for(const auto &s:state) + n_out_cols+=s->get_schema()->num_fields(); + n_out_cols-=(state.size()-1)*2; // remove column indices for key and time cols on RHS + assert(n_out_cols==output_schema->num_fields()); + } + + // Instance the types we support + std::shared_ptr i32_type=arrow::int32(); + std::shared_ptr i64_type=arrow::int64(); + std::shared_ptr f64_type=arrow::float64(); + + // Build the arrays column-by-column from our rows + std::vector> arrays(output_schema->num_fields()); + for(size_t i_table=0;i_table<_n_tables;++i_table) { + int n_src_cols=state.at(i_table)->get_schema()->num_fields(); + { + for(col_index_t i_src_col=0;i_src_col i_dst_col_opt=state[i_table]->map_src_to_dst(i_src_col); + if(!i_dst_col_opt) continue; + col_index_t i_dst_col=*i_dst_col_opt; + const auto &src_field=state[i_table]->get_schema()->field(i_src_col); + const auto &dst_field=output_schema->field(i_dst_col); + assert(src_field->type()->Equals(dst_field->type())); + assert(src_field->name()==dst_field->name()); + const auto &field_type=src_field->type(); + if(field_type->Equals(i32_type)) { + arrays.at(i_dst_col)=materialize_primitive_column(i_table,i_src_col); + } else if(field_type->Equals(i64_type)) { + arrays.at(i_dst_col)=materialize_primitive_column(i_table,i_src_col); + } else if(field_type->Equals(f64_type)) { + arrays.at(i_dst_col)=materialize_primitive_column(i_table,i_src_col); + } else { + std::cerr << 
"Unsupported data type: " << field_type->name() << "\n"; + exit(-1); // TODO: validate elsewhere for better error handling + } + } + } + } + + // Build the result + assert(sizeof(size_t)>=sizeof(int64_t)); // Make takes signed int64_t for num_rows + //TODO: check n_rows for cast + std::shared_ptr r=arrow::RecordBatch::Make(output_schema,(int64_t)n_rows,arrays); + //cerr << "materialize END (ndstrows="<< (r?r->num_rows():-1) <<")\n"; + return r; + } + + // Returns true if there are no rows + bool empty() const { + return _rows.empty(); + } +}; + class AsofJoinNode : public ExecNode { + + // Constructs labels for inputs + static std::vector build_input_labels(const std::vector &inputs) { + std::vector r(inputs.size()); + for(size_t i=0;iadvance_and_memoize(lhs_latest_time); + return any_updated; + } + + // Returns false if RHS not up to date for LHS + bool is_up_to_date_for_lhs_row() const { + auto &lhs=*_state[0]; + if(lhs.empty()) return false; // can't proceed if nothing on the LHS + int64_t lhs_ts=lhs.get_latest_time(); + for(size_t i=1;i<_state.size();++i) + { + auto &rhs=*_state[i]; + if(!rhs.finished()) { + // If RHS is finished, then we know it's up to date (but if it isn't, it might be up to date) + if(rhs.empty()) return false; // RHS isn't finished, but is empty --> not up to date + if(lhs_ts>=rhs.get_latest_time()) return false; // TS not up to date (and not finished) + } + } + return true; + } + + std::shared_ptr process_inner() { + + assert(!_state.empty()); + auto &lhs=*_state.at(0); + + // Construct new target table if needed + CompositeReferenceTable dst(_state.size()); + + // Generate rows into the dst table until we either run out of data or hit the row limit, or run out of input + for(;;) { + // If LHS is finished or empty then there's nothing we can do here + if(lhs.finished() || lhs.empty()) break; + + // Advance each of the RHS as far as possible to be up to date for the LHS timestep + bool any_advanced=update_rhs(); + + // Only update if we 
have up-to-date information for the LHS row + if(is_up_to_date_for_lhs_row()) { + dst.emplace(_state,_options._tolerance); + if(!lhs.advance()) break; // if we can't advance LHS, we're done for this batch + } else { + if((!any_advanced) && (_state.size()>1)) break; // need to wait for new data + } + } + + // Prune memo entries that have expired (to bound memory consumption) + if(!lhs.empty()) + for(size_t i=1;i<_state.size();++i) + _state[i]->remove_memo_entries_with_lesser_time(lhs.get_latest_time()-_options._tolerance); + + // Emit the batch + std::shared_ptr r=dst.empty()?nullptr:dst.materialize(output_schema(),_state); + return r; + } + + void process() { + // std::lock_guard guard(_gate); + if(finished_.is_finished()) { std::cerr << "InputReceived EARLYEND\n"; return; } + + // Process batches while we have data + for(;;) { + std::shared_ptr out_rb=process_inner(); + if(!out_rb) break; + ++_progress_batches_produced; + ExecBatch out_b(*out_rb); + outputs_[0]->InputReceived(this,std::move(out_b)); + } + + // Report to the output the total batch count, if we've already finished everything + // (there are two places where this can happen: here and InputFinished) + // + // It may happen here in cases where InputFinished was called before we were finished + // producing results (so we didn't know the output size at that time) + if(_state.at(0)->finished()) { + //cerr << "LHS is finished\n"; + _total_batches_produced=util::make_optional(_progress_batches_produced); + StopProducing(); + assert(_total_batches_produced.has_value()); + outputs_[0]->InputFinished(this,*_total_batches_produced); + } + } + + void process_thread() { + std::cerr << "AsOfJoinNode::process_thread started.\n"; + for(;;) { + if(!_process.pop()) { + std::cerr << "AsOfJoinNode::process_thread done.\n"; + return; + } + process(); + } + } + + static void process_thread_wrapper(AsofJoinNode *node) { + node->process_thread(); + } + public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, @@ -42,8 
+575,23 @@ class AsofJoinNode : public ExecNode { : ExecNode(plan, inputs, {"left", "right"}, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), - impl_(std::move(impl)) { - complete_.store(false); + impl_(std::move(impl)), + _options(join_options), + _process(1), + _process_thread(&AsOfJoinNode::process_thread_wrapper, this) + { + std::cout << "AsofJoinNode created" << "\n"; + + for(size_t i=0;i(inputs[i]->output_schema(), + *_options.time.name(), + *_options.keys.name(), + util::make_optional(0) /*TODO: make wildcard configuirable*/)); + size_t dst_offset=0; + for(auto &state:_state) + dst_offset=state->init_src_to_dst_mapping(dst_offset,!!dst_offset); + + finished_ = Future<>::MakeFinished(); } static arrow::Result Make(ExecPlan *plan, std::vector inputs, @@ -63,20 +611,88 @@ class AsofJoinNode : public ExecNode { const char* kind_name() const override { return "AsofJoinNode"; } - void InputReceived(ExecNode* input, ExecBatch batch) override {} - void ErrorReceived(ExecNode* input, Status error) override {} - void InputFinished(ExecNode* input, int total_batches) override {} - Status StartProducing() override { return Status::OK();} - void PauseProducing(ExecNode* output) override {} - void ResumeProducing(ExecNode* output) override {} - void StopProducing(ExecNode* output) override {} - void StopProducing() override {} - Future<> finished() override { return finished_; } + void InputReceived(ExecNode* input, ExecBatch batch) override { + // Get the input + ARROW_DCHECK(std::find(inputs_.begin(),inputs_.end(),input)!=inputs_.end()); + size_t k=std::find(inputs_.begin(),inputs_.end(),input)-inputs_.begin(); + std::cerr << "InputReceived BEGIN (k="<output_schema()); + + _state.at(k)->push(rb); + _process.push(true); + + std::cerr << "InputReceived END\n"; + } + void ErrorReceived(ExecNode* input, Status error) override { + outputs_[0]->ErrorReceived(this, std::move(error)); + StopProducing(); + } + void InputFinished(ExecNode* input, int 
total_batches) override { + std::cerr << "InputFinished BEGIN\n"; + bool is_finished=false; + { + ARROW_DCHECK(std::find(inputs_.begin(),inputs_.end(),input)!=inputs_.end()); + size_t k=std::find(inputs_.begin(),inputs_.end(),input)-inputs_.begin(); + //cerr << "set_total_batches for input " << k << ": " << total_batches << "\n"; + _state.at(k)->set_total_batches(total_batches); + is_finished=_state.at(k)->finished(); + } + // Trigger a process call + // The reason for this is that there are cases at the end of a table where we don't + // know whether the RHS of the join is up-to-date until we know that the table is + // finished. + _process.push(true); + + std::cerr << "InputFinished END\n"; + } + Status StartProducing() override { + finished_=arrow::Future<>::Make(); + std::cout << "StartProducing" << "\n"; + return Status::OK(); + } + void PauseProducing(ExecNode* output) override { + std::cout << "PauseProducing" << "\n"; + } + void ResumeProducing(ExecNode* output) override { + std::cout << "ResumeProducing" << "\n"; + } + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output,outputs_[0]); + StopProducing(); + std::cout << "StopProducing" << "\n"; + } + void StopProducing() override { + //if(batch_count_.Cancel()) finished_.MarkFinished(); + finished_.MarkFinished(); + for(auto&& input: inputs_) input->StopProducing(this); + } + Future<> finished() override { + std::cout << "finished" << "\n"; + return finished_; + } private: - std::atomic complete_; std::unique_ptr schema_mgr_; std::unique_ptr impl_; + Future<> finished_; + std::vector> _state; + std::mutex _gate; + AsofJoinNodeOptions _options; + + // Queue for triggering processing of a given input + // (a false value is a poison pill) + concurrent_bounded_queue _process; + + // Process thread + std::thread _process_thread; + + // Total batches produced, once we've finished -- only known at completion time. 
+ util::optional _total_batches_produced; + + // In-progress batches produced + int _progress_batches_produced=0; }; std::shared_ptr AsofJoinSchema::MakeOutputSchema(const std::vector& inputs, @@ -84,7 +700,11 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema(const std::vector> fields; assert(inputs.size() > 1); - std::vector keys(options.keys.size()); + // TODO: Deal with multi keys + // std::vector keys; + // for (auto f: options.keys) { + // keys.emplace_back(*f.name()); + // } // Directly map LHS fields for(int i = 0; i < inputs[0]->output_schema()->num_fields(); ++i) @@ -95,7 +715,7 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema(const std::vectoroutput_schema(); for(int i = 0; i < input_schema->num_fields(); ++i) { const auto &name = input_schema->field(i)->name(); - if((std::find(keys.begin(), keys.end(), name) != keys.end()) && (name!= *options.time.name())) { + if((name!= *options.keys.name()) && (name!= *options.time.name())) { fields.push_back(input_schema->field(i)); } } diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 0c8d27083b0..f20816e13f0 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -66,13 +66,13 @@ BatchesWithSchema GenerateBatchesFromString( void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r_batches, const FieldRef time, - const std::vector& keys) { + const FieldRef keys) { auto exec_ctx = arrow::internal::make_unique( default_memory_pool(), nullptr ); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - AsofJoinNodeOptions join_options{time, keys}; + AsofJoinNodeOptions join_options(time, keys, 0); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back( @@ -85,7 +85,7 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) .AddToPlan(plan.get())); - // 
ASSERT_FNISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); } void RunNonEmptyTest(bool exact_matches) { @@ -108,16 +108,14 @@ void RunNonEmptyTest(bool exact_matches) { l_batches = GenerateBatchesFromString( l_schema, - {R"([["2020-01-01", 0, 1.0]])"} - + {R"([["2020-01-01", 1, 1.0]])"} ); r_batches = GenerateBatchesFromString( l_schema, - {R"([["2020-01-01", 0, 2.0]])"} - + {R"([["2020-01-01", 1, 2.0]])"} ); - CheckRunOutput(l_batches, r_batches, "time", /*keys=*/{{"key"}}); + CheckRunOutput(l_batches, r_batches, "time", "key"); } class AsofJoinTest : public testing::TestWithParam> {}; diff --git a/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h b/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h new file mode 100644 index 00000000000..1b14cdb6eeb --- /dev/null +++ b/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace arrow { + namespace compute { + +template +class concurrent_bounded_queue { + size_t _remaining; + std::vector _buffer; + mutable std::mutex _gate; + std::condition_variable _not_full; + std::condition_variable _not_empty; + + size_t _next_push=0; + size_t _next_pop=0; +public: + concurrent_bounded_queue(size_t capacity) : _remaining(capacity),_buffer(capacity) { + } + // Push new value to queue, waiting for capacity indefinitely. + void push(const T &t) { + std::unique_lock lock(_gate); + _not_full.wait(lock,[&]{return _remaining>0;}); + _buffer[_next_push++]=t; + _next_push%=_buffer.size(); + --_remaining; + _not_empty.notify_one(); + } + // Get oldest value from queue, or wait indefinitely for it. 
+ T pop() { + std::unique_lock lock(_gate); + _not_empty.wait(lock,[&]{return _remaining<_buffer.size();}); + T r=_buffer[_next_pop++]; + _next_pop%=_buffer.size(); + ++_remaining; + _not_full.notify_one(); + return r; + } + // Try to pop the oldest value from the queue (or return nullopt if none) + util::optional try_pop() { + std::unique_lock lock(_gate); + if(_remaining==_buffer.size()) return util::nullopt; + T r=_buffer[_next_pop++]; + _next_pop%=_buffer.size(); + ++_remaining; + _not_full.notify_one(); + return r; + } + + // Test whether empty + bool empty()const { + std::unique_lock lock(_gate); + return _remaining==_buffer.size(); + } + + // Un-synchronized access to front + // For this to be "safe": + // 1) the caller logically guarantees that queue is not empty + // 2) pop/try_pop cannot be called concurrently with this + const T &unsync_front()const { + return _buffer[_next_pop]; + } +}; + + } // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 7c7ae930308..95e8953065e 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -546,6 +546,7 @@ ExecFactoryRegistry* default_exec_factory_registry() { internal::RegisterAggregateNode(this); internal::RegisterSinkNode(this); internal::RegisterHashJoinNode(this); + internal::RegisterAsofJoinNode(this); } Result GetFactory(const std::string& factory_name) override { diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 26933e2d757..383414a0a45 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -364,13 +364,15 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef time, std::vector keys) : - time(std::move(time)), keys(std::move(keys)) {} + AsofJoinNodeOptions(FieldRef time, 
FieldRef keys, int64_t tolerance) : + time(std::move(time)), keys(std::move(keys)), _tolerance(tolerance) {} // time column FieldRef time; // keys used for the join. All tables must have the same join key. - std::vector keys; + FieldRef keys; + + int64_t _tolerance; }; /// \brief Make a node which select top_k/bottom_k rows passed through it From 138daee2cca3f8b1ed57b3a7cce4ab7067a40b5c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 26 Apr 2022 09:05:54 -0400 Subject: [PATCH 04/47] wip --- cpp/src/arrow/compute/exec/asof_join_node.cc | 146 +++++++++++++----- .../arrow/compute/exec/asof_join_node_test.cc | 39 ++++- 2 files changed, 142 insertions(+), 43 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 321e72647e4..83091badbb5 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -523,7 +523,9 @@ class AsofJoinNode : public ExecNode { } void process() { - // std::lock_guard guard(_gate); + std::cerr << "process() begin\n"; + + std::lock_guard guard(_gate); if(finished_.is_finished()) { std::cerr << "InputReceived EARLYEND\n"; return; } // Process batches while we have data @@ -535,6 +537,8 @@ class AsofJoinNode : public ExecNode { outputs_[0]->InputReceived(this,std::move(out_b)); } + std::cerr << "process() end\n"; + // Report to the output the total batch count, if we've already finished everything // (there are two places where this can happen: here and InputFinished) // @@ -543,27 +547,42 @@ class AsofJoinNode : public ExecNode { if(_state.at(0)->finished()) { //cerr << "LHS is finished\n"; _total_batches_produced=util::make_optional(_progress_batches_produced); + std::cerr << "process() finished " << *_total_batches_produced << "\n"; StopProducing(); assert(_total_batches_produced.has_value()); outputs_[0]->InputFinished(this,*_total_batches_produced); } } - void process_thread() { + Status process_thread(size_t /*thread_index*/, 
int64_t /*task_id*/) { std::cerr << "AsOfJoinNode::process_thread started.\n"; - for(;;) { - if(!_process.pop()) { + auto result = _process.try_pop(); + + if (result == util::nullopt) { + std::cerr << "AsOfJoinNode::process_thread no inputs.\n"; + return Status::OK(); + } else { + if (result.value()) { + std::cerr << "AsOfJoinNode::process_thread process.\n"; + process(); + } else { std::cerr << "AsOfJoinNode::process_thread done.\n"; - return; + return Status::OK(); } - process(); } + + return Status::OK(); } - static void process_thread_wrapper(AsofJoinNode *node) { - node->process_thread(); + Status process_finished(size_t /*thread_index*/) { + std::cerr << "AsOfJoinNode::process_finished started.\n"; + return Status::OK(); } + // static void process_thread_wrapper(AsofJoinNode *node) { + // node->process_thread(); + // } + public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, @@ -571,28 +590,7 @@ class AsofJoinNode : public ExecNode { std::shared_ptr output_schema, std::unique_ptr schema_mgr, std::unique_ptr impl - ) - : ExecNode(plan, inputs, {"left", "right"}, - /*output_schema=*/std::move(output_schema), - /*num_outputs=*/1), - impl_(std::move(impl)), - _options(join_options), - _process(1), - _process_thread(&AsOfJoinNode::process_thread_wrapper, this) - { - std::cout << "AsofJoinNode created" << "\n"; - - for(size_t i=0;i(inputs[i]->output_schema(), - *_options.time.name(), - *_options.keys.name(), - util::make_optional(0) /*TODO: make wildcard configuirable*/)); - size_t dst_offset=0; - for(auto &state:_state) - dst_offset=state->init_src_to_dst_mapping(dst_offset,!!dst_offset); - - finished_ = Future<>::MakeFinished(); - } + ); static arrow::Result Make(ExecPlan *plan, std::vector inputs, const ExecNodeOptions &options) { @@ -631,13 +629,15 @@ class AsofJoinNode : public ExecNode { } void InputFinished(ExecNode* input, int total_batches) override { std::cerr << "InputFinished BEGIN\n"; - bool is_finished=false; + // bool is_finished=false; { + 
std::lock_guard guard(_gate); + std::cerr << "InputFinished find\n"; ARROW_DCHECK(std::find(inputs_.begin(),inputs_.end(),input)!=inputs_.end()); size_t k=std::find(inputs_.begin(),inputs_.end(),input)-inputs_.begin(); //cerr << "set_total_batches for input " << k << ": " << total_batches << "\n"; _state.at(k)->set_total_batches(total_batches); - is_finished=_state.at(k)->finished(); + // is_finished=_state.at(k)->finished(); } // Trigger a process call // The reason for this is that there are cases at the end of a table where we don't @@ -650,6 +650,22 @@ class AsofJoinNode : public ExecNode { Status StartProducing() override { finished_=arrow::Future<>::Make(); std::cout << "StartProducing" << "\n"; + bool use_sync_execution = !(plan_->exec_context()->executor()); + std::cerr << "StartScheduling\n"; + std::cerr << "use_sync_execution: " << use_sync_execution << std::endl; + RETURN_NOT_OK( + scheduler_->StartScheduling(0 /*thread index*/, + std::move([this](std::function func) -> Status { + return this->ScheduleTaskCallback(std::move(func)); + }), + 1, + use_sync_execution + ) + ); + RETURN_NOT_OK( + scheduler_->StartTaskGroup(0, task_group_process_, 1) + ); + std::cerr << "StartScheduling done\n"; return Status::OK(); } void PauseProducing(ExecNode* output) override { @@ -673,6 +689,25 @@ class AsofJoinNode : public ExecNode { return finished_; } + Status ScheduleTaskCallback(std::function func) { + auto executor = plan_->exec_context()->executor(); + if (executor) { + RETURN_NOT_OK(executor->Spawn([this, func] { + size_t thread_index = thread_indexer_(); + Status status = func(thread_index); + if (!status.ok()) { + StopProducing(); + ErrorIfNotOk(status); + return; + } + })); + } else { + // We should not get here in serial execution mode + ARROW_DCHECK(false); + } + return Status::OK(); + } + private: std::unique_ptr schema_mgr_; std::unique_ptr impl_; @@ -681,13 +716,13 @@ class AsofJoinNode : public ExecNode { std::mutex _gate; AsofJoinNodeOptions _options; + 
ThreadIndexer thread_indexer_; + std::unique_ptr scheduler_; + int task_group_process_; // Queue for triggering processing of a given input // (a false value is a poison pill) concurrent_bounded_queue _process; - // Process thread - std::thread _process_thread; - // Total batches produced, once we've finished -- only known at completion time. util::optional _total_batches_produced; @@ -725,6 +760,47 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema(const std::vector(fields); } + +AsofJoinNode::AsofJoinNode(ExecPlan* plan, + NodeVector inputs, + const AsofJoinNodeOptions& join_options, + std::shared_ptr output_schema, + std::unique_ptr schema_mgr, + std::unique_ptr impl + ) + : ExecNode(plan, inputs, {"left", "right"}, + /*output_schema=*/std::move(output_schema), + /*num_outputs=*/1), + impl_(std::move(impl)), + _options(join_options), + _process(1) + // _process_thread(&AsofJoinNode::process_thread_wrapper, this) +{ + std::cout << "AsofJoinNode created" << "\n"; + + for(size_t i=0;i(inputs[i]->output_schema(), + *_options.time.name(), + *_options.keys.name(), + util::make_optional(0) /*TODO: make wildcard configuirable*/)); + size_t dst_offset=0; + for(auto &state:_state) + dst_offset=state->init_src_to_dst_mapping(dst_offset,!!dst_offset); + + finished_ = Future<>::MakeFinished(); + + scheduler_ = TaskScheduler::Make(); + task_group_process_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return process_thread(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { + return process_finished(thread_index); + } + ); + scheduler_->RegisterEnd(); +} + namespace internal { void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make)); diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index f20816e13f0..51d90355f9e 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ 
b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -65,6 +65,7 @@ BatchesWithSchema GenerateBatchesFromString( void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r_batches, + const BatchesWithSchema& exp_batches, const FieldRef time, const FieldRef keys) { auto exec_ctx = arrow::internal::make_unique( @@ -86,36 +87,58 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, .AddToPlan(plan.get())); ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + + ASSERT_OK_AND_ASSIGN(auto exp_table, + TableFromExecBatches(exp_batches.schema, exp_batches.batches)); + + ASSERT_OK_AND_ASSIGN(auto res_table, + TableFromExecBatches(exp_batches.schema, res)); + + + AssertTablesEqual(*exp_table, *res_table, + /*same_chunk_layout=*/false, /*flatten=*/true); + + std::cerr << "Result Equals" << "\n"; } void RunNonEmptyTest(bool exact_matches) { auto l_schema = schema( { - field("time", timestamp(TimeUnit::NANO)), + field("time", int64()), field("key", int32()), - field("l_v0", float32()) + field("l_v0", float64()) } ); auto r_schema = schema( { - field("time", timestamp(TimeUnit::NANO)), + field("time", int64()), field("key", int32()), - field("r_v0", float32()) + field("r_v0", float64()) } ); + auto exp_schema = schema({ + field("time", int64()), + field("key", int32()), + field("l_v0", float64()), + field("r_v0", float64()), + }); BatchesWithSchema l_batches, r_batches, exp_batches; l_batches = GenerateBatchesFromString( l_schema, - {R"([["2020-01-01", 1, 1.0]])"} + {R"([[0, 1, 1.0]])"} ); r_batches = GenerateBatchesFromString( - l_schema, - {R"([["2020-01-01", 1, 2.0]])"} + r_schema, + {R"([[0, 1, 2.0]])"} ); + exp_batches = GenerateBatchesFromString( + exp_schema, + {R"([[0, 1, 1.0, 2.0]])"} + ); - CheckRunOutput(l_batches, r_batches, "time", "key"); + CheckRunOutput(l_batches, r_batches, exp_batches, "time", "key"); } class AsofJoinTest : public testing::TestWithParam> {}; From 
94a8453ecbfe3ac9e9860f858b5e7623a9af3daa Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 26 Apr 2022 15:05:27 -0400 Subject: [PATCH 05/47] wip: First test pass --- cpp/src/arrow/compute/exec/asof_join_node.cc | 179 ++++++++++--------- 1 file changed, 99 insertions(+), 80 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 83091badbb5..4397acb4ad8 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -554,35 +554,48 @@ class AsofJoinNode : public ExecNode { } } - Status process_thread(size_t /*thread_index*/, int64_t /*task_id*/) { - std::cerr << "AsOfJoinNode::process_thread started.\n"; - auto result = _process.try_pop(); - - if (result == util::nullopt) { - std::cerr << "AsOfJoinNode::process_thread no inputs.\n"; - return Status::OK(); - } else { - if (result.value()) { - std::cerr << "AsOfJoinNode::process_thread process.\n"; - process(); - } else { - std::cerr << "AsOfJoinNode::process_thread done.\n"; - return Status::OK(); + // Status process_thread(size_t /*thread_index*/, int64_t /*task_id*/) { + // std::cerr << "AsOfJoinNode::process_thread started.\n"; + // auto result = _process.try_pop(); + + // if (result == util::nullopt) { + // std::cerr << "AsOfJoinNode::process_thread no inputs.\n"; + // return Status::OK(); + // } else { + // if (result.value()) { + // std::cerr << "AsOfJoinNode::process_thread process.\n"; + // process(); + // } else { + // std::cerr << "AsOfJoinNode::process_thread done.\n"; + // return Status::OK(); + // } + // } + + // return Status::OK(); + // } + + // Status process_finished(size_t /*thread_index*/) { + // std::cerr << "AsOfJoinNode::process_finished started.\n"; + // return Status::OK(); + // } + + void process_thread() { + std::cerr << "AsOfMergeNode::process_thread started.\n"; + for(;;) { + if(!_process.pop()) { + std::cerr << "AsOfMergeNode::process_thread done.\n"; + return; } + //cerr << 
"AsOfMergeNode::process() BEGIN\n"; + process(); + //cerr << "AsOfMergeNode::process() END\n"; } - - return Status::OK(); } - - Status process_finished(size_t /*thread_index*/) { - std::cerr << "AsOfJoinNode::process_finished started.\n"; - return Status::OK(); + + static void process_thread_wrapper(AsofJoinNode *node) { + node->process_thread(); } - // static void process_thread_wrapper(AsofJoinNode *node) { - // node->process_thread(); - // } - public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, @@ -592,6 +605,11 @@ class AsofJoinNode : public ExecNode { std::unique_ptr impl ); + virtual ~AsofJoinNode() { + _process.push(false); // poison pill + _process_thread.join(); + } + static arrow::Result Make(ExecPlan *plan, std::vector inputs, const ExecNodeOptions &options) { std::unique_ptr schema_mgr = @@ -648,24 +666,24 @@ class AsofJoinNode : public ExecNode { std::cerr << "InputFinished END\n"; } Status StartProducing() override { - finished_=arrow::Future<>::Make(); std::cout << "StartProducing" << "\n"; - bool use_sync_execution = !(plan_->exec_context()->executor()); - std::cerr << "StartScheduling\n"; - std::cerr << "use_sync_execution: " << use_sync_execution << std::endl; - RETURN_NOT_OK( - scheduler_->StartScheduling(0 /*thread index*/, - std::move([this](std::function func) -> Status { - return this->ScheduleTaskCallback(std::move(func)); - }), - 1, - use_sync_execution - ) - ); - RETURN_NOT_OK( - scheduler_->StartTaskGroup(0, task_group_process_, 1) - ); - std::cerr << "StartScheduling done\n"; + finished_=arrow::Future<>::Make(); + // bool use_sync_execution = !(plan_->exec_context()->executor()); + // std::cerr << "StartScheduling\n"; + // std::cerr << "use_sync_execution: " << use_sync_execution << std::endl; + // RETURN_NOT_OK( + // scheduler_->StartScheduling(0 /*thread index*/, + // std::move([this](std::function func) -> Status { + // return this->ScheduleTaskCallback(std::move(func)); + // }), + // 1, + // use_sync_execution + // ) + // ); + // 
RETURN_NOT_OK( + // scheduler_->StartTaskGroup(0, task_group_process_, 1) + // ); + // std::cerr << "StartScheduling done\n"; return Status::OK(); } void PauseProducing(ExecNode* output) override { @@ -680,33 +698,31 @@ class AsofJoinNode : public ExecNode { std::cout << "StopProducing" << "\n"; } void StopProducing() override { + std::cerr << "StopProducing" << std::endl; //if(batch_count_.Cancel()) finished_.MarkFinished(); - finished_.MarkFinished(); + finished_.MarkFinished(); for(auto&& input: inputs_) input->StopProducing(this); } - Future<> finished() override { - std::cout << "finished" << "\n"; - return finished_; - } - - Status ScheduleTaskCallback(std::function func) { - auto executor = plan_->exec_context()->executor(); - if (executor) { - RETURN_NOT_OK(executor->Spawn([this, func] { - size_t thread_index = thread_indexer_(); - Status status = func(thread_index); - if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - return; - } - })); - } else { - // We should not get here in serial execution mode - ARROW_DCHECK(false); - } - return Status::OK(); - } + Future<> finished() override { return finished_; } + + // Status ScheduleTaskCallback(std::function func) { + // auto executor = plan_->exec_context()->executor(); + // if (executor) { + // RETURN_NOT_OK(executor->Spawn([this, func] { + // size_t thread_index = thread_indexer_(); + // Status status = func(thread_index); + // if (!status.ok()) { + // StopProducing(); + // ErrorIfNotOk(status); + // return; + // } + // })); + // } else { + // // We should not get here in serial execution mode + // ARROW_DCHECK(false); + // } + // return Status::OK(); + // } private: std::unique_ptr schema_mgr_; @@ -716,12 +732,15 @@ class AsofJoinNode : public ExecNode { std::mutex _gate; AsofJoinNodeOptions _options; - ThreadIndexer thread_indexer_; - std::unique_ptr scheduler_; - int task_group_process_; + //ThreadIndexer thread_indexer_; + //std::unique_ptr scheduler_; + // int task_group_process_; + // 
Queue for triggering processing of a given input // (a false value is a poison pill) concurrent_bounded_queue _process; + // Worker thread + std::thread _process_thread; // Total batches produced, once we've finished -- only known at completion time. util::optional _total_batches_produced; @@ -773,8 +792,8 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, /*num_outputs=*/1), impl_(std::move(impl)), _options(join_options), - _process(1) - // _process_thread(&AsofJoinNode::process_thread_wrapper, this) + _process(1), + _process_thread(&AsofJoinNode::process_thread_wrapper, this) { std::cout << "AsofJoinNode created" << "\n"; @@ -789,16 +808,16 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, finished_ = Future<>::MakeFinished(); - scheduler_ = TaskScheduler::Make(); - task_group_process_ = scheduler_->RegisterTaskGroup( - [this](size_t thread_index, int64_t task_id) -> Status { - return process_thread(thread_index, task_id); - }, - [this](size_t thread_index) -> Status { - return process_finished(thread_index); - } - ); - scheduler_->RegisterEnd(); + // scheduler_ = TaskScheduler::Make(); + // task_group_process_ = scheduler_->RegisterTaskGroup( + // [this](size_t thread_index, int64_t task_id) -> Status { + // return process_thread(thread_index, task_id); + // }, + // [this](size_t thread_index) -> Status { + // return process_finished(thread_index); + // } + // ); + // scheduler_->RegisterEnd(); } namespace internal { From 6466b80342fe915aeb3609d94675dea74e6e7893 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 28 Apr 2022 16:18:20 -0400 Subject: [PATCH 06/47] Fix code style and lint (partial) --- cpp/src/arrow/compute/exec/asof_join.cc | 23 +- cpp/src/arrow/compute/exec/asof_join.h | 15 +- cpp/src/arrow/compute/exec/asof_join_node.cc | 910 +++++++++--------- .../arrow/compute/exec/asof_join_node_test.cc | 97 +- .../compute/exec/concurrent_bounded_queue.h | 125 +-- cpp/src/arrow/compute/exec/options.h | 5 +- 6 files changed, 601 insertions(+), 574 deletions(-) diff --git 
a/cpp/src/arrow/compute/exec/asof_join.cc b/cpp/src/arrow/compute/exec/asof_join.cc index d70bce213be..d214ef3078e 100644 --- a/cpp/src/arrow/compute/exec/asof_join.cc +++ b/cpp/src/arrow/compute/exec/asof_join.cc @@ -24,22 +24,22 @@ #include #include //#include -#include -#include #include #include #include +#include +#include #include -#include #include -#include // so we don't need to require C++20 +#include // so we don't need to require C++20 #include -#include +#include #include -#include -#include -#include #include +#include +#include +#include +#include #include @@ -48,12 +48,7 @@ namespace arrow { namespace compute { - - - -class AsofJoinBasicImpl : public AsofJoinImpl { - -}; +class AsofJoinBasicImpl : public AsofJoinImpl {}; Result> AsofJoinImpl::MakeBasic() { std::unique_ptr impl{new AsofJoinBasicImpl()}; diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h index b23c8aef1ec..ee5492bcc18 100644 --- a/cpp/src/arrow/compute/exec/asof_join.h +++ b/cpp/src/arrow/compute/exec/asof_join.h @@ -23,6 +23,10 @@ #include #include +#include +#include +#include // so we don't need to require C++20 +#include #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/task_util.h" @@ -30,10 +34,6 @@ #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/tracing_internal.h" -#include -#include -#include -#include // so we don't need to require C++20 #include "concurrent_bounded_queue.h" @@ -43,7 +43,7 @@ namespace compute { typedef int32_t KeyType; // Maximum number of tables that can be joined -#define MAX_JOIN_TABLES 6 +#define MAX_JOIN_TABLES 64 // Capacity of the input queues (for flow control) // Why 2? 
@@ -61,15 +61,12 @@ typedef int col_index_t; class AsofJoinSchema { public: std::shared_ptr MakeOutputSchema(const std::vector& inputs, - const AsofJoinNodeOptions& options - ); - + const AsofJoinNodeOptions& options); }; class AsofJoinImpl { public: static Result> MakeBasic(); - }; } // namespace compute diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 4397acb4ad8..2a57c0c1f47 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -18,8 +18,8 @@ #include #include -#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/asof_join.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/util.h" @@ -34,247 +34,265 @@ #include #include //#include -#include -#include #include #include #include +#include +#include #include -#include #include -#include // so we don't need to require C++20 +#include // so we don't need to require C++20 #include -#include +#include #include -#include -#include -#include #include +#include +#include +#include +#include #include - #include "concurrent_bounded_queue.h" namespace arrow { namespace compute { struct MemoStore { + // Stores last known values for all the keys + struct Entry { - // Timestamp associated with the entry + // Timestamp associated with the entry int64_t _time; - // Batch associated with the entry (perf is probably OK for this; batches change rarely) + // Batch associated with the entry (perf is probably OK for this; batches change + // rarely) std::shared_ptr _batch; // Row associated with the entry row_index_t _row; }; - std::unordered_map _entries; + std::unordered_map _entries; - void store(const std::shared_ptr &batch, row_index_t row, int64_t time, KeyType key) { - auto &e=_entries[key]; + void store(const std::shared_ptr& batch, row_index_t row, int64_t time, + KeyType key) { + 
auto& e = _entries[key]; // that we can do this assignment optionally, is why we // can get array with using shared_ptr above (the batch // shouldn't change that often) - if(e._batch!=batch) e._batch=batch; - e._row=row; - e._time=time; + if (e._batch != batch) e._batch = batch; + e._row = row; + e._time = time; } util::optional get_entry_for_key(KeyType key) const { - auto e=_entries.find(key); - if(_entries.end()==e) - return util::nullopt; + auto e = _entries.find(key); + if (_entries.end() == e) return util::nullopt; return util::optional(&e->second); } void remove_entries_with_lesser_time(int64_t ts) { - size_t dbg_size0=_entries.size(); - for(auto e=_entries.begin();e!=_entries.end();) - if(e->second._timesecond._time < ts) + e = _entries.erase(e); + else + ++e; + size_t dbg_size1 = _entries.size(); + if (dbg_size1 < dbg_size0) { + // cerr << "Removed " << dbg_size0-dbg_size1 << " memo entries.\n"; + } } }; -class InputState -{ - // Pending record batches. The latest is the front. Batches cannot be empty. - concurrent_bounded_queue > _queue; - - // Wildcard key for this input, if applicable. 
- util::optional _wildcard_key; - - // Schema associated with the input - std::shared_ptr _schema; - - // Total number of batches (only int because InputFinished uses int) - int _total_batches=-1; - - // Number of batches processed so far (only int because InputFinished uses int) - int _batches_processed=0; - - // Index of the time col - col_index_t _time_col_index; - - // Index of the key col - col_index_t _key_col_index; - - // Index of the latest row reference within; if >0 then _queue cannot be empty - row_index_t _latest_ref_row=0; // must be < _queue.front()->num_rows() if _queue is non-empty - - // Stores latest known values for the various keys - MemoStore _memo; - - // Mapping of source columns to destination columns - std::vector> src_to_dst; - -public: - InputState(const std::shared_ptr &schema, - const std::string &time_col_name, - const std::string &key_col_name, +class InputState { + // InputState correponds to an input + // Input record batches are queued up in InputState until processed and + // turned into output record batches. 
+ + public: + InputState(const std::shared_ptr& schema, + const std::string& time_col_name, const std::string& key_col_name, util::optional wildcard_key) - : - _queue(QUEUE_CAPACITY), - _wildcard_key(wildcard_key), - _schema(schema), - _time_col_index(schema->GetFieldIndex(time_col_name)), //TODO: handle missing field name - _key_col_index(schema->GetFieldIndex(key_col_name)) //TODO: handle missing field name - { /*nothing else*/ } - - size_t init_src_to_dst_mapping(size_t dst_offset,bool skip_time_and_key_fields) { - src_to_dst.resize(_schema->num_fields()); - for(int i=0;i<_schema->num_fields();++i) - if(!(skip_time_and_key_fields && is_time_or_key_column(i))) - src_to_dst[i]=dst_offset++; - return dst_offset; - } + : _queue(QUEUE_CAPACITY), + _wildcard_key(wildcard_key), + _schema(schema), + _time_col_index( + schema->GetFieldIndex(time_col_name)), // TODO: handle missing field name + _key_col_index( + schema->GetFieldIndex(key_col_name)) // TODO: handle missing field name + { /*nothing else*/ + } - const util::optional& map_src_to_dst(col_index_t src)const { - return src_to_dst[src]; - } + size_t init_src_to_dst_mapping(size_t dst_offset, bool skip_time_and_key_fields) { + src_to_dst.resize(_schema->num_fields()); + for (int i = 0; i < _schema->num_fields(); ++i) + if (!(skip_time_and_key_fields && is_time_or_key_column(i))) + src_to_dst[i] = dst_offset++; + return dst_offset; + } + + const util::optional& map_src_to_dst(col_index_t src) const { + return src_to_dst[src]; + } bool is_time_or_key_column(col_index_t i) const { - assert(i<_schema->num_fields()); - return (i==_time_col_index) || (i==_key_col_index); + assert(i < _schema->num_fields()); + return (i == _time_col_index) || (i == _key_col_index); } // Gets the latest row index, assuming the queue isn't empty - row_index_t get_latest_row() const { - return _latest_ref_row; - } + row_index_t get_latest_row() const { return _latest_ref_row; } bool empty() const { - if(_latest_ref_row>0) return false; // 
cannot be empty if ref row is >0 -- can avoid slow queue lock below - return _queue.empty(); + if (_latest_ref_row > 0) + return false; // cannot be empty if ref row is >0 -- can avoid slow queue lock + // below + return _queue.empty(); } - int count_batches_processed()const { return _batches_processed; } - int count_total_batches()const { return _total_batches; } + int count_batches_processed() const { return _batches_processed; } + int count_total_batches() const { return _total_batches; } - // Gets latest batch (precondition: must not be empty) - const std::shared_ptr &get_latest_batch()const { - return _queue.unsync_front(); - } + // Gets latest batch (precondition: must not be empty) + const std::shared_ptr& get_latest_batch() const { + return _queue.unsync_front(); + } KeyType get_latest_key() const { - return _queue.unsync_front()->column_data(_key_col_index)->GetValues(1)[_latest_ref_row]; + return _queue.unsync_front() + ->column_data(_key_col_index) + ->GetValues(1)[_latest_ref_row]; } - int64_t get_latest_time()const { - return _queue.unsync_front()->column_data(_time_col_index)->GetValues(1)[_latest_ref_row]; + int64_t get_latest_time() const { + return _queue.unsync_front() + ->column_data(_time_col_index) + ->GetValues(1)[_latest_ref_row]; } - bool finished() const { - return _batches_processed==_total_batches; - } - - bool advance() { - // Returns true if able to advance, false if not. - bool have_active_batch=(_latest_ref_row>0/*short circuit the lock on the queue*/) || !_queue.empty(); - if(have_active_batch) { - // If we have an active batch - if(++_latest_ref_row>=_queue.unsync_front()->num_rows()) { - // hit the end of the batch, need to get the next batch if possible. 
- ++_batches_processed; - _latest_ref_row=0; - have_active_batch&=!_queue.try_pop(); - if(have_active_batch) assert(_queue.unsync_front()->num_rows()>0); // empty batches disallowed - } - } - return have_active_batch; + bool finished() const { return _batches_processed == _total_batches; } + + bool advance() { + // Returns true if able to advance, false if not. + + bool have_active_batch = + (_latest_ref_row > 0 /*short circuit the lock on the queue*/) || !_queue.empty(); + if (have_active_batch) { + // If we have an active batch + if (++_latest_ref_row >= _queue.unsync_front()->num_rows()) { + // hit the end of the batch, need to get the next batch if possible. + ++_batches_processed; + _latest_ref_row = 0; + have_active_batch &= !_queue.try_pop(); + if (have_active_batch) + assert(_queue.unsync_front()->num_rows() > 0); // empty batches disallowed + } } + return have_active_batch; + } - // Advance the data to be immediately past the specified TS, updating latest and latest_ref_row to - // the latest data prior to that immediate just past - // Returns true if updates were made, false if not. - bool advance_and_memoize(int64_t ts) { - // Check if already updated for TS (or if there is no latest) - if(empty()) return false; // can't advance if empty - auto latest_time=get_latest_time(); - if(latest_time>ts) return false; // already advanced - - // Not updated. Try to update and possibly advance. - bool updated=false; - do { - latest_time=get_latest_time(); - if(latest_time<=ts) // if advance() returns true, then the latest_ts must also be valid - _memo.store(get_latest_batch(),_latest_ref_row,latest_time,get_latest_key()); - else - break; // hit a future timestamp -- done updating for now - updated=true; - } while (advance()); - return updated; - } + // Advance the data to be immediately past the specified TS, updating latest and + // latest_ref_row to the latest data prior to that immediate just past Returns true if + // updates were made, false if not. 
+ bool advance_and_memoize(int64_t ts) { + // Advance the right side row index until we reach the latest right row (for each key) + // for the given left timestamp. + + // Check if already updated for TS (or if there is no latest) + if (empty()) return false; // can't advance if empty + auto latest_time = get_latest_time(); + if (latest_time > ts) return false; // already advanced + + // Not updated. Try to update and possibly advance. + bool updated = false; + do { + latest_time = get_latest_time(); + if (latest_time <= + ts) // if advance() returns true, then the latest_ts must also be valid + // Keep advancing right table until we hit the latest row that has + // timestamp <= ts. This is because we only need the latest row for the + // match given a left ts. However, for a futre + _memo.store(get_latest_batch(), _latest_ref_row, latest_time, get_latest_key()); + else + break; // hit a future timestamp -- done updating for now + updated = true; + } while (advance()); + return updated; + } - void push(const std::shared_ptr &rb) { - if(rb->num_rows()>0) { - _queue.push(rb); - } else { - ++_batches_processed; // don't enqueue empty batches, just record as processed - } + void push(const std::shared_ptr& rb) { + if (rb->num_rows() > 0) { + _queue.push(rb); + } else { + ++_batches_processed; // don't enqueue empty batches, just record as processed } + } util::optional get_memo_entry_for_key(KeyType key) { - auto r=_memo.get_entry_for_key(key); - if(r.has_value()) return r; - if(_wildcard_key.has_value()) - r=_memo.get_entry_for_key(*_wildcard_key); + auto r = _memo.get_entry_for_key(key); + if (r.has_value()) return r; + if (_wildcard_key.has_value()) r = _memo.get_entry_for_key(*_wildcard_key); return r; } util::optional get_memo_time_for_key(KeyType key) { - auto r=get_memo_entry_for_key(key); - return r.has_value()?util::make_optional((*r)->_time):util::nullopt; + auto r = get_memo_entry_for_key(key); + return r.has_value() ? 
util::make_optional((*r)->_time) : util::nullopt; } - void remove_memo_entries_with_lesser_time(int64_t ts) { - _memo.remove_entries_with_lesser_time(ts); - } + void remove_memo_entries_with_lesser_time(int64_t ts) { + _memo.remove_entries_with_lesser_time(ts); + } - const std::shared_ptr &get_schema() const { return _schema; } + const std::shared_ptr& get_schema() const { return _schema; } - void set_total_batches(int n) { - assert(n>=0); // not sure why arrow uses a signed int for this, but it should be >=0 - assert(_total_batches==-1); // shouldn't be set more than once - _total_batches=n; - } + void set_total_batches(int n) { + assert(n >= + 0); // not sure why arrow uses a signed int for this, but it should be >=0 + assert(_total_batches == -1); // shouldn't be set more than once + _total_batches = n; + } + + private: + // Pending record batches. The latest is the front. Batches cannot be empty. + concurrent_bounded_queue> _queue; + + // Wildcard key for this input, if applicable. + util::optional _wildcard_key; + + // Schema associated with the input + std::shared_ptr _schema; + + // Total number of batches (only int because InputFinished uses int) + int _total_batches = -1; + + // Number of batches processed so far (only int because InputFinished uses int) + int _batches_processed = 0; + + // Index of the time col + col_index_t _time_col_index; + + // Index of the key col + col_index_t _key_col_index; + + // Index of the latest row reference within; if >0 then _queue cannot be empty + row_index_t _latest_ref_row = + 0; // must be < _queue.front()->num_rows() if _queue is non-empty + + // Stores latest known values for the various keys + MemoStore _memo; + + // Mapping of source columns to destination columns + std::vector> src_to_dst; }; template -struct CompositeReferenceRow -{ - struct Entry - { - arrow::RecordBatch *_batch; // can be NULL if there's no value - row_index_t _row; - }; - Entry _refs[MAX_TABLES]; +struct CompositeReferenceRow { + struct Entry 
{ + arrow::RecordBatch* _batch; // can be NULL if there's no value + row_index_t _row; + }; + Entry _refs[MAX_TABLES]; }; // A table of composite reference rows. Rows maintain pointers to the @@ -289,236 +307,252 @@ struct CompositeReferenceRow // We don't put the shared_ptr's into the rows for efficiency reasons. template class CompositeReferenceTable { - // Contains shared_ptr refs for all RecordBatches referred to by the contents of _rows - std::unordered_map> _ptr2ref; + // Contains shared_ptr refs for all RecordBatches referred to by the contents of _rows + std::unordered_map> _ptr2ref; - // Row table references + // Row table references std::vector> _rows; - // Total number of tables in the composite table + // Total number of tables in the composite table size_t _n_tables; - // Adds a RecordBatch ref to the mapping, if needed - void add_record_batch_ref(const std::shared_ptr &ref) { - if(!_ptr2ref.count((uintptr_t)ref.get())) - _ptr2ref[(uintptr_t)ref.get()]=ref; - } -public: - CompositeReferenceTable(size_t n_tables) : _n_tables(n_tables) - { - assert(_n_tables>=1); - assert(_n_tables<=MAX_TABLES); - } + // Adds a RecordBatch ref to the mapping, if needed + void add_record_batch_ref(const std::shared_ptr& ref) { + if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; + } - size_t n_rows() const { return _rows.size(); } + public: + CompositeReferenceTable(size_t n_tables) : _n_tables(n_tables) { + assert(_n_tables >= 1); + assert(_n_tables <= MAX_TABLES); + } - // Adds the latest row from the input state as a new composite reference row - // - LHS must have a valid key,timestep,and latest rows - // - RHS must have valid data memo'ed for the key - void emplace(std::vector> &in,int64_t tolerance) { - assert(in.size()==_n_tables); + size_t n_rows() const { return _rows.size(); } + + // Adds the latest row from the input state as a new composite reference row + // - LHS must have a valid key,timestep,and latest rows + // - RHS must 
have valid data memo'ed for the key + void emplace(std::vector>& in, int64_t tolerance) { + assert(in.size() == _n_tables); // Get the LHS key - KeyType key=in[0]->get_latest_key(); + KeyType key = in[0]->get_latest_key(); // Add row and setup LHS // (the LHS state comes just from the latest row of the LHS table) assert(!in[0]->empty()); - const std::shared_ptr &lhs_latest_batch=in[0]->get_latest_batch(); - row_index_t lhs_latest_row=in[0]->get_latest_row(); - int64_t lhs_latest_time=in[0]->get_latest_time(); - if(0==lhs_latest_row) { + const std::shared_ptr& lhs_latest_batch = + in[0]->get_latest_batch(); + row_index_t lhs_latest_row = in[0]->get_latest_row(); + int64_t lhs_latest_time = in[0]->get_latest_time(); + if (0 == lhs_latest_row) { // On the first row of the batch, we resize the destination. // The destination size is dictated by the size of the LHS batch. - assert(lhs_latest_batch->num_rows()<=MAX_ROWS_PER_BATCH); //TODO: better error handling - row_index_t new_batch_size=lhs_latest_batch->num_rows(); - row_index_t new_capacity=_rows.size()+new_batch_size; - if(_rows.capacity() + std::shared_ptr materialize_primitive_column(size_t i_table, size_t i_col) { + Builder builder; + // builder.Resize(_rows.size()); // <-- can't just do this -- need to set the bitmask + builder.AppendEmptyValues(_rows.size()); + for (row_index_t i_row = 0; i_row < _rows.size(); ++i_row) { + const auto& ref = _rows[i_row]._refs[i_table]; + if (ref._batch) + builder[i_row] = + ref._batch->column_data(i_col)->template GetValues( + 1)[ref._row]; + // TODO: set null value if ref._batch is null -- currently we don't due to API + // limitations of the builders. 
+ } + std::shared_ptr result; + if (!builder.Finish(&result).ok()) { + std::cerr << "Error when creating Arrow array from builder\n"; + exit(-1); // TODO: better error handling + } + return result; + } + + public: + // Materializes the current reference table into a target record batch + std::shared_ptr materialize( + const std::shared_ptr& output_schema, + const std::vector>& state) { + // cerr << "materialize BEGIN\n"; + assert(state.size() == _n_tables); + assert(state.size() >= 1); + + // Don't build empty batches + size_t n_rows = _rows.size(); + if (!n_rows) return nullptr; + + // Count output columns (dbg sanitycheck) + { + int n_out_cols = 0; + for (const auto& s : state) n_out_cols += s->get_schema()->num_fields(); + n_out_cols -= + (state.size() - 1) * 2; // remove column indices for key and time cols on RHS + assert(n_out_cols == output_schema->num_fields()); } -public: - // Materializes the current reference table into a target record batch - std::shared_ptr materialize(const std::shared_ptr &output_schema, - const std::vector> &state) { - //cerr << "materialize BEGIN\n"; - assert(state.size()==_n_tables); - assert(state.size()>=1); - - // Don't build empty batches - size_t n_rows=_rows.size(); - if(!n_rows) return nullptr; - - // Count output columns (dbg sanitycheck) - { - int n_out_cols=0; - for(const auto &s:state) - n_out_cols+=s->get_schema()->num_fields(); - n_out_cols-=(state.size()-1)*2; // remove column indices for key and time cols on RHS - assert(n_out_cols==output_schema->num_fields()); - } - // Instance the types we support - std::shared_ptr i32_type=arrow::int32(); - std::shared_ptr i64_type=arrow::int64(); - std::shared_ptr f64_type=arrow::float64(); - - // Build the arrays column-by-column from our rows - std::vector> arrays(output_schema->num_fields()); - for(size_t i_table=0;i_table<_n_tables;++i_table) { - int n_src_cols=state.at(i_table)->get_schema()->num_fields(); - { - for(col_index_t i_src_col=0;i_src_col 
i_dst_col_opt=state[i_table]->map_src_to_dst(i_src_col); - if(!i_dst_col_opt) continue; - col_index_t i_dst_col=*i_dst_col_opt; - const auto &src_field=state[i_table]->get_schema()->field(i_src_col); - const auto &dst_field=output_schema->field(i_dst_col); - assert(src_field->type()->Equals(dst_field->type())); - assert(src_field->name()==dst_field->name()); - const auto &field_type=src_field->type(); - if(field_type->Equals(i32_type)) { - arrays.at(i_dst_col)=materialize_primitive_column(i_table,i_src_col); - } else if(field_type->Equals(i64_type)) { - arrays.at(i_dst_col)=materialize_primitive_column(i_table,i_src_col); - } else if(field_type->Equals(f64_type)) { - arrays.at(i_dst_col)=materialize_primitive_column(i_table,i_src_col); - } else { - std::cerr << "Unsupported data type: " << field_type->name() << "\n"; - exit(-1); // TODO: validate elsewhere for better error handling - } - } + // Instance the types we support + std::shared_ptr i32_type = arrow::int32(); + std::shared_ptr i64_type = arrow::int64(); + std::shared_ptr f64_type = arrow::float64(); + + // Build the arrays column-by-column from our rows + std::vector> arrays(output_schema->num_fields()); + for (size_t i_table = 0; i_table < _n_tables; ++i_table) { + int n_src_cols = state.at(i_table)->get_schema()->num_fields(); + { + for (col_index_t i_src_col = 0; i_src_col < n_src_cols; ++i_src_col) { + util::optional i_dst_col_opt = + state[i_table]->map_src_to_dst(i_src_col); + if (!i_dst_col_opt) continue; + col_index_t i_dst_col = *i_dst_col_opt; + const auto& src_field = state[i_table]->get_schema()->field(i_src_col); + const auto& dst_field = output_schema->field(i_dst_col); + assert(src_field->type()->Equals(dst_field->type())); + assert(src_field->name() == dst_field->name()); + const auto& field_type = src_field->type(); + if (field_type->Equals(i32_type)) { + arrays.at(i_dst_col) = + materialize_primitive_column(i_table, + i_src_col); + } else if (field_type->Equals(i64_type)) { + 
arrays.at(i_dst_col) = + materialize_primitive_column(i_table, + i_src_col); + } else if (field_type->Equals(f64_type)) { + arrays.at(i_dst_col) = + materialize_primitive_column(i_table, + i_src_col); + } else { + std::cerr << "Unsupported data type: " << field_type->name() << "\n"; + exit(-1); // TODO: validate elsewhere for better error handling } } - - // Build the result - assert(sizeof(size_t)>=sizeof(int64_t)); // Make takes signed int64_t for num_rows - //TODO: check n_rows for cast - std::shared_ptr r=arrow::RecordBatch::Make(output_schema,(int64_t)n_rows,arrays); - //cerr << "materialize END (ndstrows="<< (r?r->num_rows():-1) <<")\n"; - return r; + } } - // Returns true if there are no rows - bool empty() const { - return _rows.empty(); - } + // Build the result + assert(sizeof(size_t) >= sizeof(int64_t)); // Make takes signed int64_t for num_rows + // TODO: check n_rows for cast + std::shared_ptr r = + arrow::RecordBatch::Make(output_schema, (int64_t)n_rows, arrays); + // cerr << "materialize END (ndstrows="<< (r?r->num_rows():-1) <<")\n"; + return r; + } + + // Returns true if there are no rows + bool empty() const { return _rows.empty(); } }; class AsofJoinNode : public ExecNode { - - // Constructs labels for inputs - static std::vector build_input_labels(const std::vector &inputs) { + // Constructs labels for inputs + static std::vector build_input_labels( + const std::vector& inputs) { std::vector r(inputs.size()); - for(size_t i=0;iadvance_and_memoize(lhs_latest_time); + auto& lhs = *_state.at(0); + auto lhs_latest_time = lhs.get_latest_time(); + bool any_updated = false; + for (size_t i = 1; i < _state.size(); ++i) + any_updated |= _state[i]->advance_and_memoize(lhs_latest_time); return any_updated; } // Returns false if RHS not up to date for LHS bool is_up_to_date_for_lhs_row() const { - auto &lhs=*_state[0]; - if(lhs.empty()) return false; // can't proceed if nothing on the LHS - int64_t lhs_ts=lhs.get_latest_time(); - for(size_t 
i=1;i<_state.size();++i) - { - auto &rhs=*_state[i]; - if(!rhs.finished()) { - // If RHS is finished, then we know it's up to date (but if it isn't, it might be up to date) - if(rhs.empty()) return false; // RHS isn't finished, but is empty --> not up to date - if(lhs_ts>=rhs.get_latest_time()) return false; // TS not up to date (and not finished) - } - } + auto& lhs = *_state[0]; + if (lhs.empty()) return false; // can't proceed if nothing on the LHS + int64_t lhs_ts = lhs.get_latest_time(); + for (size_t i = 1; i < _state.size(); ++i) { + auto& rhs = *_state[i]; + if (!rhs.finished()) { + // If RHS is finished, then we know it's up to date (but if it isn't, it might be + // up to date) + if (rhs.empty()) + return false; // RHS isn't finished, but is empty --> not up to date + if (lhs_ts >= rhs.get_latest_time()) + return false; // TS not up to date (and not finished) + } + } return true; } std::shared_ptr process_inner() { - assert(!_state.empty()); - auto &lhs=*_state.at(0); + auto& lhs = *_state.at(0); // Construct new target table if needed CompositeReferenceTable dst(_state.size()); - // Generate rows into the dst table until we either run out of data or hit the row limit, or run out of input - for(;;) { + // Generate rows into the dst table until we either run out of data or hit the row + // limit, or run out of input + for (;;) { // If LHS is finished or empty then there's nothing we can do here - if(lhs.finished() || lhs.empty()) break; + if (lhs.finished() || lhs.empty()) break; - // Advance each of the RHS as far as possible to be up to date for the LHS timestep - bool any_advanced=update_rhs(); + // Advance each of the RHS as far as possible to be up to date for the LHS timestamp + bool any_advanced = update_rhs(); - // Only update if we have up-to-date information for the LHS row - if(is_up_to_date_for_lhs_row()) { - dst.emplace(_state,_options._tolerance); - if(!lhs.advance()) break; // if we can't advance LHS, we're done for this batch + // Only 
update if we have up-to-date information for the LHS row + if (is_up_to_date_for_lhs_row()) { + dst.emplace(_state, _options._tolerance); + if (!lhs.advance()) break; // if we can't advance LHS, we're done for this batch } else { - if((!any_advanced) && (_state.size()>1)) break; // need to wait for new data + if ((!any_advanced) && (_state.size() > 1)) break; // need to wait for new data } } - // Prune memo entries that have expired (to bound memory consumption) - if(!lhs.empty()) - for(size_t i=1;i<_state.size();++i) - _state[i]->remove_memo_entries_with_lesser_time(lhs.get_latest_time()-_options._tolerance); + // Prune memo entries that have expired (to bound memory consumption) + if (!lhs.empty()) + for (size_t i = 1; i < _state.size(); ++i) + _state[i]->remove_memo_entries_with_lesser_time(lhs.get_latest_time() - + _options._tolerance); // Emit the batch - std::shared_ptr r=dst.empty()?nullptr:dst.materialize(output_schema(),_state); + std::shared_ptr r = + dst.empty() ? nullptr : dst.materialize(output_schema(), _state); return r; } @@ -526,15 +560,18 @@ class AsofJoinNode : public ExecNode { std::cerr << "process() begin\n"; std::lock_guard guard(_gate); - if(finished_.is_finished()) { std::cerr << "InputReceived EARLYEND\n"; return; } + if (finished_.is_finished()) { + std::cerr << "InputReceived EARLYEND\n"; + return; + } // Process batches while we have data - for(;;) { - std::shared_ptr out_rb=process_inner(); - if(!out_rb) break; + for (;;) { + std::shared_ptr out_rb = process_inner(); + if (!out_rb) break; ++_progress_batches_produced; ExecBatch out_b(*out_rb); - outputs_[0]->InputReceived(this,std::move(out_b)); + outputs_[0]->InputReceived(this, std::move(out_b)); } std::cerr << "process() end\n"; @@ -544,13 +581,13 @@ class AsofJoinNode : public ExecNode { // // It may happen here in cases where InputFinished was called before we were finished // producing results (so we didn't know the output size at that time) - if(_state.at(0)->finished()) { - 
//cerr << "LHS is finished\n"; - _total_batches_produced=util::make_optional(_progress_batches_produced); + if (_state.at(0)->finished()) { + // cerr << "LHS is finished\n"; + _total_batches_produced = util::make_optional(_progress_batches_produced); std::cerr << "process() finished " << *_total_batches_produced << "\n"; StopProducing(); assert(_total_batches_produced.has_value()); - outputs_[0]->InputFinished(this,*_total_batches_produced); + outputs_[0]->InputFinished(this, *_total_batches_produced); } } @@ -581,60 +618,53 @@ class AsofJoinNode : public ExecNode { void process_thread() { std::cerr << "AsOfMergeNode::process_thread started.\n"; - for(;;) { - if(!_process.pop()) { - std::cerr << "AsOfMergeNode::process_thread done.\n"; - return; + for (;;) { + if (!_process.pop()) { + std::cerr << "AsOfMergeNode::process_thread done.\n"; + return; } - //cerr << "AsOfMergeNode::process() BEGIN\n"; process(); - //cerr << "AsOfMergeNode::process() END\n"; } } - - static void process_thread_wrapper(AsofJoinNode *node) { - node->process_thread(); - } + + static void process_thread_wrapper(AsofJoinNode* node) { node->process_thread(); } public: - AsofJoinNode(ExecPlan* plan, - NodeVector inputs, - const AsofJoinNodeOptions& join_options, + AsofJoinNode(ExecPlan* plan, NodeVector inputs, const AsofJoinNodeOptions& join_options, std::shared_ptr output_schema, std::unique_ptr schema_mgr, - std::unique_ptr impl - ); + std::unique_ptr impl); virtual ~AsofJoinNode() { - _process.push(false); // poison pill + _process.push(false); // poison pill _process_thread.join(); } - - static arrow::Result Make(ExecPlan *plan, std::vector inputs, - const ExecNodeOptions &options) { + + static arrow::Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { std::unique_ptr schema_mgr = - ::arrow::internal::make_unique(); + ::arrow::internal::make_unique(); const auto& join_options = checked_cast(options); - std::shared_ptr output_schema = 
schema_mgr->MakeOutputSchema(inputs, join_options); + std::shared_ptr output_schema = + schema_mgr->MakeOutputSchema(inputs, join_options); ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, AsofJoinImpl::MakeBasic()); - return plan->EmplaceNode( - plan, inputs, join_options, std::move(output_schema), std::move(schema_mgr), - std::move(impl) - ); + return plan->EmplaceNode(plan, inputs, join_options, + std::move(output_schema), + std::move(schema_mgr), std::move(impl)); } const char* kind_name() const override { return "AsofJoinNode"; } void InputReceived(ExecNode* input, ExecBatch batch) override { // Get the input - ARROW_DCHECK(std::find(inputs_.begin(),inputs_.end(),input)!=inputs_.end()); - size_t k=std::find(inputs_.begin(),inputs_.end(),input)-inputs_.begin(); - std::cerr << "InputReceived BEGIN (k="<output_schema()); + auto rb = *batch.ToRecordBatch(input->output_schema()); _state.at(k)->push(rb); _process.push(true); @@ -651,9 +681,9 @@ class AsofJoinNode : public ExecNode { { std::lock_guard guard(_gate); std::cerr << "InputFinished find\n"; - ARROW_DCHECK(std::find(inputs_.begin(),inputs_.end(),input)!=inputs_.end()); - size_t k=std::find(inputs_.begin(),inputs_.end(),input)-inputs_.begin(); - //cerr << "set_total_batches for input " << k << ": " << total_batches << "\n"; + ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); + size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + // cerr << "set_total_batches for input " << k << ": " << total_batches << "\n"; _state.at(k)->set_total_batches(total_batches); // is_finished=_state.at(k)->finished(); } @@ -666,15 +696,18 @@ class AsofJoinNode : public ExecNode { std::cerr << "InputFinished END\n"; } Status StartProducing() override { - std::cout << "StartProducing" << "\n"; - finished_=arrow::Future<>::Make(); + std::cout << "StartProducing" + << "\n"; + finished_ = arrow::Future<>::Make(); // bool use_sync_execution = !(plan_->exec_context()->executor()); // 
std::cerr << "StartScheduling\n"; // std::cerr << "use_sync_execution: " << use_sync_execution << std::endl; // RETURN_NOT_OK( // scheduler_->StartScheduling(0 /*thread index*/, - // std::move([this](std::function func) -> Status { - // return this->ScheduleTaskCallback(std::move(func)); + // std::move([this](std::function + // func) -> Status { + // return + // this->ScheduleTaskCallback(std::move(func)); // }), // 1, // use_sync_execution @@ -687,21 +720,24 @@ class AsofJoinNode : public ExecNode { return Status::OK(); } void PauseProducing(ExecNode* output) override { - std::cout << "PauseProducing" << "\n"; + std::cout << "PauseProducing" + << "\n"; } void ResumeProducing(ExecNode* output) override { - std::cout << "ResumeProducing" << "\n"; + std::cout << "ResumeProducing" + << "\n"; } void StopProducing(ExecNode* output) override { - DCHECK_EQ(output,outputs_[0]); + DCHECK_EQ(output, outputs_[0]); StopProducing(); - std::cout << "StopProducing" << "\n"; + std::cout << "StopProducing" + << "\n"; } void StopProducing() override { std::cerr << "StopProducing" << std::endl; - //if(batch_count_.Cancel()) finished_.MarkFinished(); - finished_.MarkFinished(); - for(auto&& input: inputs_) input->StopProducing(this); + // if(batch_count_.Cancel()) finished_.MarkFinished(); + finished_.MarkFinished(); + for (auto&& input : inputs_) input->StopProducing(this); } Future<> finished() override { return finished_; } @@ -728,12 +764,15 @@ class AsofJoinNode : public ExecNode { std::unique_ptr schema_mgr_; std::unique_ptr impl_; Future<> finished_; + // InputStates + // Each input state correponds to an input table + // std::vector> _state; std::mutex _gate; AsofJoinNodeOptions _options; - //ThreadIndexer thread_indexer_; - //std::unique_ptr scheduler_; + // ThreadIndexer thread_indexer_; + // std::unique_ptr scheduler_; // int task_group_process_; // Queue for triggering processing of a given input @@ -746,11 +785,11 @@ class AsofJoinNode : public ExecNode { util::optional 
_total_batches_produced; // In-progress batches produced - int _progress_batches_produced=0; + int _progress_batches_produced = 0; }; -std::shared_ptr AsofJoinSchema::MakeOutputSchema(const std::vector& inputs, - const AsofJoinNodeOptions& options) { +std::shared_ptr AsofJoinSchema::MakeOutputSchema( + const std::vector& inputs, const AsofJoinNodeOptions& options) { std::vector> fields; assert(inputs.size() > 1); @@ -761,15 +800,15 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema(const std::vectoroutput_schema()->num_fields(); ++i) + for (int i = 0; i < inputs[0]->output_schema()->num_fields(); ++i) fields.push_back(inputs[0]->output_schema()->field(i)); // Take all non-key, non-time RHS fields - for(size_t j = 1; j < inputs.size(); ++j) { - const auto &input_schema = inputs[j]->output_schema(); - for(int i = 0; i < input_schema->num_fields(); ++i) { - const auto &name = input_schema->field(i)->name(); - if((name!= *options.keys.name()) && (name!= *options.time.name())) { + for (size_t j = 1; j < inputs.size(); ++j) { + const auto& input_schema = inputs[j]->output_schema(); + for (int i = 0; i < input_schema->num_fields(); ++i) { + const auto& name = input_schema->field(i)->name(); + if ((name != *options.keys.name()) && (name != *options.time.name())) { fields.push_back(input_schema->field(i)); } } @@ -779,52 +818,53 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema(const std::vector(fields); } - -AsofJoinNode::AsofJoinNode(ExecPlan* plan, - NodeVector inputs, +AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, const AsofJoinNodeOptions& join_options, std::shared_ptr output_schema, std::unique_ptr schema_mgr, - std::unique_ptr impl - ) - : ExecNode(plan, inputs, {"left", "right"}, - /*output_schema=*/std::move(output_schema), - /*num_outputs=*/1), - impl_(std::move(impl)), - _options(join_options), - _process(1), - _process_thread(&AsofJoinNode::process_thread_wrapper, this) -{ - std::cout << "AsofJoinNode created" << "\n"; - - for(size_t 
i=0;i(inputs[i]->output_schema(), - *_options.time.name(), - *_options.keys.name(), - util::make_optional(0) /*TODO: make wildcard configuirable*/)); - size_t dst_offset=0; - for(auto &state:_state) - dst_offset=state->init_src_to_dst_mapping(dst_offset,!!dst_offset); - - finished_ = Future<>::MakeFinished(); - - // scheduler_ = TaskScheduler::Make(); - // task_group_process_ = scheduler_->RegisterTaskGroup( - // [this](size_t thread_index, int64_t task_id) -> Status { - // return process_thread(thread_index, task_id); - // }, - // [this](size_t thread_index) -> Status { - // return process_finished(thread_index); - // } - // ); - // scheduler_->RegisterEnd(); + std::unique_ptr impl) + : ExecNode(plan, inputs, {"left", "right"}, + /*output_schema=*/std::move(output_schema), + /*num_outputs=*/1), + impl_(std::move(impl)), + _options(join_options), + _process(1), + _process_thread(&AsofJoinNode::process_thread_wrapper, this) { + std::cout << "AsofJoinNode created" + << "\n"; + + for (size_t i = 0; i < inputs.size(); ++i) + _state.push_back(::arrow::internal::make_unique( + inputs[i]->output_schema(), *_options.time.name(), *_options.keys.name(), + util::make_optional(0) /*TODO: make wildcard configuirable*/)); + size_t dst_offset = 0; + for (auto& state : _state) + dst_offset = state->init_src_to_dst_mapping(dst_offset, !!dst_offset); + + finished_ = Future<>::MakeFinished(); + + // scheduler_ = TaskScheduler::Make(); + // task_group_process_ = scheduler_->RegisterTaskGroup( + // [this](size_t thread_index, + // int64_t task_id) -> Status { + // return + // process_thread(thread_index, + // task_id); + // }, + // [this](size_t thread_index) -> + // Status { + // return + // process_finished(thread_index); + // } + // ); + // scheduler_->RegisterEnd(); } namespace internal { - void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { - DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make)); - } +void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { + 
DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make)); } +} // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 51d90355f9e..8286a317953 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -65,93 +65,74 @@ BatchesWithSchema GenerateBatchesFromString( void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r_batches, - const BatchesWithSchema& exp_batches, - const FieldRef time, - const FieldRef keys) { - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), nullptr - ); + const BatchesWithSchema& exp_batches, const FieldRef time, + const FieldRef keys, const long tolerance) { + auto exec_ctx = + arrow::internal::make_unique(default_memory_pool(), nullptr); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - AsofJoinNodeOptions join_options(time, keys, 0); + AsofJoinNodeOptions join_options(time, keys, tolerance); Declaration join{"asofjoin", join_options}; - join.inputs.emplace_back( - Declaration{"source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); - join.inputs.emplace_back( - Declaration{"source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); AsyncGenerator> sink_gen; ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) - .AddToPlan(plan.get())); + .AddToPlan(plan.get())); ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); ASSERT_OK_AND_ASSIGN(auto exp_table, TableFromExecBatches(exp_batches.schema, exp_batches.batches)); - ASSERT_OK_AND_ASSIGN(auto res_table, 
- TableFromExecBatches(exp_batches.schema, res)); - + ASSERT_OK_AND_ASSIGN(auto res_table, TableFromExecBatches(exp_batches.schema, res)); AssertTablesEqual(*exp_table, *res_table, /*same_chunk_layout=*/false, /*flatten=*/true); - std::cerr << "Result Equals" << "\n"; + std::cerr << "Result Equals" + << "\n"; } void RunNonEmptyTest(bool exact_matches) { - auto l_schema = schema( - { - field("time", int64()), - field("key", int32()), - field("l_v0", float64()) - } - ); - auto r_schema = schema( - { - field("time", int64()), - field("key", int32()), - field("r_v0", float64()) - } - ); + auto l_schema = + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); + auto r_schema = + schema({field("time", int64()), field("key", int32()), field("r_v0", float64())}); auto exp_schema = schema({ - field("time", int64()), - field("key", int32()), - field("l_v0", float64()), - field("r_v0", float64()), - }); + field("time", int64()), + field("key", int32()), + field("l_v0", float64()), + field("r_v0", float64()), + }); BatchesWithSchema l_batches, r_batches, exp_batches; - l_batches = GenerateBatchesFromString( - l_schema, - {R"([[0, 1, 1.0]])"} - ); - r_batches = GenerateBatchesFromString( - r_schema, - {R"([[0, 1, 2.0]])"} - ); - exp_batches = GenerateBatchesFromString( - exp_schema, - {R"([[0, 1, 1.0, 2.0]])"} - ); - - CheckRunOutput(l_batches, r_batches, exp_batches, "time", "key"); -} + l_batches = GenerateBatchesFromString(l_schema, {R"([[1000, 1, 1.0]])"}); + r_batches = GenerateBatchesFromString(r_schema, {R"([[0, 1, 2.0]])"}); + exp_batches = GenerateBatchesFromString(exp_schema, {R"([[1000, 1, 1.0, 2.0]])"}); - class AsofJoinTest : public testing::TestWithParam> {}; + CheckRunOutput(l_batches, r_batches, exp_batches, "time", "key", 1000); -INSTANTIATE_TEST_SUITE_P( - AsofJoinTest, AsofJoinTest, - ::testing::Combine( - ::testing::Values(false, true) - )); + l_batches = GenerateBatchesFromString(l_schema, {R"([[1000, 1, 1.0]])"}); + r_batches 
= GenerateBatchesFromString(r_schema, {R"([[0, 1, 2.0]])"}); + // This is wrong + // TODO: Fix null values in the result + exp_batches = GenerateBatchesFromString(exp_schema, {R"([[1000, 1, 1.0, 0.0]])"}); -TEST_P(AsofJoinTest, TestExactMatches) { - RunNonEmptyTest(std::get<0>(GetParam())); + CheckRunOutput(l_batches, r_batches, exp_batches, "time", "key", 999); } +class AsofJoinTest : public testing::TestWithParam> {}; + +INSTANTIATE_TEST_SUITE_P(AsofJoinTest, AsofJoinTest, + ::testing::Combine(::testing::Values(false, true))); + +TEST_P(AsofJoinTest, TestExactMatches) { RunNonEmptyTest(std::get<0>(GetParam())); } + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h b/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h index 1b14cdb6eeb..5451ad08e1f 100644 --- a/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h +++ b/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h @@ -1,72 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ #pragma once -#include #include +#include #include #include #include namespace arrow { - namespace compute { +namespace compute { template class concurrent_bounded_queue { - size_t _remaining; - std::vector _buffer; - mutable std::mutex _gate; - std::condition_variable _not_full; - std::condition_variable _not_empty; + size_t _remaining; + std::vector _buffer; + mutable std::mutex _gate; + std::condition_variable _not_full; + std::condition_variable _not_empty; + + size_t _next_push = 0; + size_t _next_pop = 0; - size_t _next_push=0; - size_t _next_pop=0; -public: - concurrent_bounded_queue(size_t capacity) : _remaining(capacity),_buffer(capacity) { - } - // Push new value to queue, waiting for capacity indefinitely. - void push(const T &t) { - std::unique_lock lock(_gate); - _not_full.wait(lock,[&]{return _remaining>0;}); - _buffer[_next_push++]=t; - _next_push%=_buffer.size(); - --_remaining; - _not_empty.notify_one(); - } - // Get oldest value from queue, or wait indefinitely for it. - T pop() { - std::unique_lock lock(_gate); - _not_empty.wait(lock,[&]{return _remaining<_buffer.size();}); - T r=_buffer[_next_pop++]; - _next_pop%=_buffer.size(); - ++_remaining; - _not_full.notify_one(); - return r; - } - // Try to pop the oldest value from the queue (or return nullopt if none) - util::optional try_pop() { - std::unique_lock lock(_gate); - if(_remaining==_buffer.size()) return util::nullopt; - T r=_buffer[_next_pop++]; - _next_pop%=_buffer.size(); - ++_remaining; - _not_full.notify_one(); - return r; - } + public: + concurrent_bounded_queue(size_t capacity) : _remaining(capacity), _buffer(capacity) {} + // Push new value to queue, waiting for capacity indefinitely. + void push(const T& t) { + std::unique_lock lock(_gate); + _not_full.wait(lock, [&] { return _remaining > 0; }); + _buffer[_next_push++] = t; + _next_push %= _buffer.size(); + --_remaining; + _not_empty.notify_one(); + } + // Get oldest value from queue, or wait indefinitely for it. 
+ T pop() { + std::unique_lock lock(_gate); + _not_empty.wait(lock, [&] { return _remaining < _buffer.size(); }); + T r = _buffer[_next_pop++]; + _next_pop %= _buffer.size(); + ++_remaining; + _not_full.notify_one(); + return r; + } + // Try to pop the oldest value from the queue (or return nullopt if none) + util::optional try_pop() { + std::unique_lock lock(_gate); + if (_remaining == _buffer.size()) return util::nullopt; + T r = _buffer[_next_pop++]; + _next_pop %= _buffer.size(); + ++_remaining; + _not_full.notify_one(); + return r; + } - // Test whether empty - bool empty()const { - std::unique_lock lock(_gate); - return _remaining==_buffer.size(); - } + // Test whether empty + bool empty() const { + std::unique_lock lock(_gate); + return _remaining == _buffer.size(); + } - // Un-synchronized access to front - // For this to be "safe": - // 1) the caller logically guarantees that queue is not empty - // 2) pop/try_pop cannot be called concurrently with this - const T &unsync_front()const { - return _buffer[_next_pop]; - } + // Un-synchronized access to front + // For this to be "safe": + // 1) the caller logically guarantees that queue is not empty + // 2) pop/try_pop cannot be called concurrently with this + const T& unsync_front() const { return _buffer[_next_pop]; } }; - } // namespace compute -} // namespace arrow +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 383414a0a45..9817b9b2180 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -361,11 +361,10 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { Expression filter; }; - class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef time, FieldRef keys, int64_t tolerance) : - time(std::move(time)), keys(std::move(keys)), _tolerance(tolerance) {} + AsofJoinNodeOptions(FieldRef time, FieldRef keys, int64_t 
tolerance) + : time(std::move(time)), keys(std::move(keys)), _tolerance(tolerance) {} // time column FieldRef time; From 4c334521677facfe7ee1a0e38256a6bd3294dad5 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 9 May 2022 16:32:57 -0400 Subject: [PATCH 07/47] Add support for mutliple tables; Add more tests --- cpp/src/arrow/compute/exec/asof_join_node.cc | 22 ++- .../arrow/compute/exec/asof_join_node_test.cc | 137 +++++++++++++++--- cpp/src/arrow/compute/exec/exec_plan.cc | 2 + 3 files changed, 134 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 2a57c0c1f47..702080f5df3 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -617,10 +617,10 @@ class AsofJoinNode : public ExecNode { // } void process_thread() { - std::cerr << "AsOfMergeNode::process_thread started.\n"; + std::cerr << "AsofJoinNode::process_thread started.\n"; for (;;) { if (!_process.pop()) { - std::cerr << "AsOfMergeNode::process_thread done.\n"; + std::cerr << "AsofJoinNode::process_thread done.\n"; return; } process(); @@ -630,7 +630,9 @@ class AsofJoinNode : public ExecNode { static void process_thread_wrapper(AsofJoinNode* node) { node->process_thread(); } public: - AsofJoinNode(ExecPlan* plan, NodeVector inputs, const AsofJoinNodeOptions& join_options, + AsofJoinNode(ExecPlan* plan, NodeVector inputs, + std::vector input_labels, + const AsofJoinNodeOptions& join_options, std::shared_ptr output_schema, std::unique_ptr schema_mgr, std::unique_ptr impl); @@ -650,7 +652,16 @@ class AsofJoinNode : public ExecNode { schema_mgr->MakeOutputSchema(inputs, join_options); ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, AsofJoinImpl::MakeBasic()); - return plan->EmplaceNode(plan, inputs, join_options, + std::cerr << "Input size: " << inputs.size() << "\n"; + std::vector input_labels(inputs.size()); + input_labels[0] = "left"; + for(size_t i = 1; i < 
inputs.size(); ++i) { + input_labels[i] = "right_" + std::to_string(i); + } + + return plan->EmplaceNode(plan, inputs, + std::move(input_labels), + join_options, std::move(output_schema), std::move(schema_mgr), std::move(impl)); } @@ -819,11 +830,12 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema( } AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, + std::vector input_labels, const AsofJoinNodeOptions& join_options, std::shared_ptr output_schema, std::unique_ptr schema_mgr, std::unique_ptr impl) - : ExecNode(plan, inputs, {"left", "right"}, + : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), impl_(std::move(impl)), diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 8286a317953..bc7e7b0b633 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -64,7 +64,8 @@ BatchesWithSchema GenerateBatchesFromString( } void CheckRunOutput(const BatchesWithSchema& l_batches, - const BatchesWithSchema& r_batches, + const BatchesWithSchema& r0_batches, + const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, const FieldRef time, const FieldRef keys, const long tolerance) { auto exec_ctx = @@ -77,7 +78,9 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); join.inputs.emplace_back(Declaration{ - "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); + "source", SourceNodeOptions{r0_batches.schema, r0_batches.gen(false, false)}}); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r1_batches.schema, r1_batches.gen(false, false)}}); AsyncGenerator> sink_gen; @@ -94,37 +97,127 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, AssertTablesEqual(*exp_table, *res_table, /*same_chunk_layout=*/false, 
/*flatten=*/true); - std::cerr << "Result Equals" - << "\n"; + std::cerr << "Result Equals" << "\n"; } void RunNonEmptyTest(bool exact_matches) { auto l_schema = schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); - auto r_schema = - schema({field("time", int64()), field("key", int32()), field("r_v0", float64())}); + auto r0_schema = + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())}); + auto r1_schema = + schema({field("time", int64()), field("key", int32()), field("r1_v0", float64())}); + auto exp_schema = schema({ field("time", int64()), field("key", int32()), field("l_v0", float64()), - field("r_v0", float64()), + field("r0_v0", float64()), + field("r1_v0", float64()), }); - BatchesWithSchema l_batches, r_batches, exp_batches; - - l_batches = GenerateBatchesFromString(l_schema, {R"([[1000, 1, 1.0]])"}); - r_batches = GenerateBatchesFromString(r_schema, {R"([[0, 1, 2.0]])"}); - exp_batches = GenerateBatchesFromString(exp_schema, {R"([[1000, 1, 1.0, 2.0]])"}); - - CheckRunOutput(l_batches, r_batches, exp_batches, "time", "key", 1000); - - l_batches = GenerateBatchesFromString(l_schema, {R"([[1000, 1, 1.0]])"}); - r_batches = GenerateBatchesFromString(r_schema, {R"([[0, 1, 2.0]])"}); - // This is wrong - // TODO: Fix null values in the result - exp_batches = GenerateBatchesFromString(exp_schema, {R"([[1000, 1, 1.0, 0.0]])"}); - - CheckRunOutput(l_batches, r_batches, exp_batches, "time", "key", 999); + // Test three table join + BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; + + // Single key, single batch + l_batches = GenerateBatchesFromString(l_schema, + {R"([[0, 1, 1.0], [1000, 1, 2.0]])"} + ); + r0_batches = GenerateBatchesFromString(r0_schema, {R"([[0, 1, 11.0]])"}); + r1_batches = GenerateBatchesFromString(r1_schema, {R"([[1000, 1, 101.0]])"}); + exp_batches = GenerateBatchesFromString(exp_schema, + {R"([[0, 1, 1.0, 11.0, 0.0], [1000, 1, 2.0, 11.0, 101.0]])"} + ); + 
CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); + + // Single key, multiple batches + l_batches = GenerateBatchesFromString(l_schema, + {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"} + ); + r0_batches = GenerateBatchesFromString(r0_schema, + {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"} + ); + r1_batches = GenerateBatchesFromString(r1_schema, + {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"} + ); + exp_batches = GenerateBatchesFromString(exp_schema, + {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"} + ); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); + + // Single key, multiple left batches, single right batches + + l_batches = GenerateBatchesFromString(l_schema, + {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"} + ); + + r0_batches = GenerateBatchesFromString(r0_schema, + {R"([[0, 1, 11.0], [1000, 1, 12.0]])"} + ); + r1_batches = GenerateBatchesFromString(r1_schema, + {R"([[0, 1, 101.0], [1000, 1, 102.0]])"} + ); + exp_batches = GenerateBatchesFromString(exp_schema, + {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"} + ); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); + + // Multi key, multiple batches, misaligned batches + l_batches = GenerateBatchesFromString(l_schema, + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"} + ); + r0_batches = GenerateBatchesFromString(r0_schema, + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])" + }); + r1_batches = GenerateBatchesFromString(r0_schema, + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])" + }); + exp_batches = GenerateBatchesFromString(exp_schema, + {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 
12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"} + ); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); + + // Multi key, multiple batches, misaligned batches, smaller tolerance + l_batches = GenerateBatchesFromString(l_schema, + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"} + ); + r0_batches = GenerateBatchesFromString(r0_schema, + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])" + }); + r1_batches = GenerateBatchesFromString(r0_schema, + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])" + }); + exp_batches = GenerateBatchesFromString(exp_schema, + {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 0.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"} + ); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 500); + + // Multi key, multiple batches, misaligned batches, 0 tolerance + l_batches = GenerateBatchesFromString(l_schema, + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"} + ); + r0_batches = GenerateBatchesFromString(r0_schema, + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])" + }); + r1_batches = GenerateBatchesFromString(r0_schema, + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])" + }); + exp_batches = GenerateBatchesFromString(exp_schema, + {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 0.0, 101.0], [1000, 2, 22.0, 0.0, 0.0], [1500, 1, 3.0, 
0.0, 0.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 0.0, 0.0]])"} + ); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 0); } class AsofJoinTest : public testing::TestWithParam> {}; diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 95e8953065e..33f98f69a90 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -17,6 +17,7 @@ #include "arrow/compute/exec/exec_plan.h" +#include #include #include #include @@ -75,6 +76,7 @@ struct ExecPlanImpl : public ExecPlan { return Status::Invalid("ExecPlan has no node"); } for (const auto& node : nodes_) { + std::cerr << node->kind_name() << "\n"; RETURN_NOT_OK(node->Validate()); } return Status::OK(); From c6c6093b10c6465b3c7d4aa55d68d7e70dfeaa0e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 10 May 2022 15:52:11 -0400 Subject: [PATCH 08/47] Clean up code style (Pass ninja lint now), switch to unbounded queue --- cpp/src/arrow/compute/exec/asof_join.cc | 2 - cpp/src/arrow/compute/exec/asof_join.h | 2 - cpp/src/arrow/compute/exec/asof_join_node.cc | 333 ++++++++---------- .../arrow/compute/exec/asof_join_node_test.cc | 146 ++++---- .../compute/exec/concurrent_bounded_queue.h | 87 ----- 5 files changed, 203 insertions(+), 367 deletions(-) delete mode 100644 cpp/src/arrow/compute/exec/concurrent_bounded_queue.h diff --git a/cpp/src/arrow/compute/exec/asof_join.cc b/cpp/src/arrow/compute/exec/asof_join.cc index d214ef3078e..e3d541e0d4b 100644 --- a/cpp/src/arrow/compute/exec/asof_join.cc +++ b/cpp/src/arrow/compute/exec/asof_join.cc @@ -43,8 +43,6 @@ #include -#include "concurrent_bounded_queue.h" - namespace arrow { namespace compute { diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h index ee5492bcc18..4ed71c8f76e 100644 --- a/cpp/src/arrow/compute/exec/asof_join.h +++ b/cpp/src/arrow/compute/exec/asof_join.h @@ 
-35,8 +35,6 @@ #include "arrow/type.h" #include "arrow/util/tracing_internal.h" -#include "concurrent_bounded_queue.h" - namespace arrow { namespace compute { diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 702080f5df3..07a15baec60 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -35,15 +35,11 @@ #include //#include #include -#include -#include #include #include #include -#include #include // so we don't need to require C++20 #include -#include #include #include #include @@ -53,11 +49,58 @@ #include -#include "concurrent_bounded_queue.h" - namespace arrow { namespace compute { +/** + * Simple implementation for an unbound concurrent queue + */ +template +class concurrent_queue { + public: + T pop() { + std::unique_lock lock(mutex_); + cond_.wait(lock, [&] { return !queue_.empty();}); + auto item = queue_.front(); + queue_.pop(); + return item; + } + + void push(const T& item) { + std::unique_lock lock(mutex_); + queue_.push(item); + cond_.notify_one(); + } + + util::optional try_pop() { + // Try to pop the oldest value from the queue (or return nullopt if none) + std::unique_lock lock(mutex_); + if (queue_.empty()) { + return util::nullopt; + } else { + auto item = queue_.front(); + queue_.pop(); + return item; + } + } + + bool empty() const { + std::unique_lock lock(mutex_); + return queue_.empty(); + } + + // Un-synchronized access to front + // For this to be "safe": + // 1) the caller logically guarantees that queue is not empty + // 2) pop/try_pop cannot be called concurrently with this + const T& unsync_front() const { return queue_.front(); } + + private: + std::queue queue_; + mutable std::mutex mutex_; + std::condition_variable cond_; +}; + struct MemoStore { // Stores last known values for all the keys @@ -110,82 +153,79 @@ class InputState { // InputState correponds to an input // Input record batches are queued up in InputState 
until processed and // turned into output record batches. - + public: InputState(const std::shared_ptr& schema, const std::string& time_col_name, const std::string& key_col_name, util::optional wildcard_key) - : _queue(QUEUE_CAPACITY), - _wildcard_key(wildcard_key), - _schema(schema), - _time_col_index( + : queue_(), + wildcard_key_(wildcard_key), + schema_(schema), + time_col_index_( schema->GetFieldIndex(time_col_name)), // TODO: handle missing field name - _key_col_index( - schema->GetFieldIndex(key_col_name)) // TODO: handle missing field name - { /*nothing else*/ - } + key_col_index_(schema->GetFieldIndex(key_col_name)) {} size_t init_src_to_dst_mapping(size_t dst_offset, bool skip_time_and_key_fields) { - src_to_dst.resize(_schema->num_fields()); - for (int i = 0; i < _schema->num_fields(); ++i) + src_to_dst_.resize(schema_->num_fields()); + for (int i = 0; i < schema_->num_fields(); ++i) if (!(skip_time_and_key_fields && is_time_or_key_column(i))) - src_to_dst[i] = dst_offset++; + src_to_dst_[i] = dst_offset++; return dst_offset; } const util::optional& map_src_to_dst(col_index_t src) const { - return src_to_dst[src]; + return src_to_dst_[src]; } bool is_time_or_key_column(col_index_t i) const { - assert(i < _schema->num_fields()); - return (i == _time_col_index) || (i == _key_col_index); + assert(i < schema_->num_fields()); + return (i == time_col_index_) || (i == key_col_index_); } // Gets the latest row index, assuming the queue isn't empty - row_index_t get_latest_row() const { return _latest_ref_row; } + row_index_t get_latest_row() const { return latest_ref_row_; } bool empty() const { - if (_latest_ref_row > 0) + if (latest_ref_row_ > 0) return false; // cannot be empty if ref row is >0 -- can avoid slow queue lock // below - return _queue.empty(); + return queue_.empty(); } - int count_batches_processed() const { return _batches_processed; } - int count_total_batches() const { return _total_batches; } + int countbatches_processed_() const { return 
batches_processed_; } + int count_total_batches() const { return total_batches_; } // Gets latest batch (precondition: must not be empty) const std::shared_ptr& get_latest_batch() const { - return _queue.unsync_front(); + return queue_.unsync_front(); } KeyType get_latest_key() const { - return _queue.unsync_front() - ->column_data(_key_col_index) - ->GetValues(1)[_latest_ref_row]; + return queue_.unsync_front() + ->column_data(key_col_index_) + ->GetValues(1)[latest_ref_row_]; } int64_t get_latest_time() const { - return _queue.unsync_front() - ->column_data(_time_col_index) - ->GetValues(1)[_latest_ref_row]; + return queue_.unsync_front() + ->column_data(time_col_index_) + ->GetValues(1)[latest_ref_row_]; } - bool finished() const { return _batches_processed == _total_batches; } + bool finished() const { return batches_processed_ == total_batches_; } bool advance() { // Returns true if able to advance, false if not. - + bool have_active_batch = - (_latest_ref_row > 0 /*short circuit the lock on the queue*/) || !_queue.empty(); + (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.empty(); if (have_active_batch) { // If we have an active batch - if (++_latest_ref_row >= _queue.unsync_front()->num_rows()) { + if (++latest_ref_row_ >= queue_.unsync_front()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. 
- ++_batches_processed; - _latest_ref_row = 0; - have_active_batch &= !_queue.try_pop(); + ++batches_processed_; + latest_ref_row_ = 0; + have_active_batch &= !queue_.try_pop(); if (have_active_batch) - assert(_queue.unsync_front()->num_rows() > 0); // empty batches disallowed + assert(queue_.unsync_front()->num_rows() > 0); // empty batches disallowed } } return have_active_batch; @@ -207,14 +247,15 @@ class InputState { bool updated = false; do { latest_time = get_latest_time(); - if (latest_time <= - ts) // if advance() returns true, then the latest_ts must also be valid - // Keep advancing right table until we hit the latest row that has - // timestamp <= ts. This is because we only need the latest row for the - // match given a left ts. However, for a futre - _memo.store(get_latest_batch(), _latest_ref_row, latest_time, get_latest_key()); - else + // if advance() returns true, then the latest_ts must also be valid + // Keep advancing right table until we hit the latest row that has + // timestamp <= ts. This is because we only need the latest row for the + // match given a left ts. 
+ if (latest_time <= ts) { + memo_.store(get_latest_batch(), latest_ref_row_, latest_time, get_latest_key()); + } else { break; // hit a future timestamp -- done updating for now + } updated = true; } while (advance()); return updated; @@ -222,16 +263,16 @@ class InputState { void push(const std::shared_ptr& rb) { if (rb->num_rows() > 0) { - _queue.push(rb); + queue_.push(rb); } else { - ++_batches_processed; // don't enqueue empty batches, just record as processed + ++batches_processed_; // don't enqueue empty batches, just record as processed } } util::optional get_memo_entry_for_key(KeyType key) { - auto r = _memo.get_entry_for_key(key); + auto r = memo_.get_entry_for_key(key); if (r.has_value()) return r; - if (_wildcard_key.has_value()) r = _memo.get_entry_for_key(*_wildcard_key); + if (wildcard_key_.has_value()) r = memo_.get_entry_for_key(*wildcard_key_); return r; } @@ -241,49 +282,48 @@ class InputState { } void remove_memo_entries_with_lesser_time(int64_t ts) { - _memo.remove_entries_with_lesser_time(ts); + memo_.remove_entries_with_lesser_time(ts); } - const std::shared_ptr& get_schema() const { return _schema; } + const std::shared_ptr& get_schema() const { return schema_; } void set_total_batches(int n) { - assert(n >= - 0); // not sure why arrow uses a signed int for this, but it should be >=0 - assert(_total_batches == -1); // shouldn't be set more than once - _total_batches = n; + assert(n >= 0); + assert(total_batches_ == -1); // shouldn't be set more than once + total_batches_ = n; } private: // Pending record batches. The latest is the front. Batches cannot be empty. - concurrent_bounded_queue> _queue; + concurrent_queue> queue_; // Wildcard key for this input, if applicable. 
- util::optional _wildcard_key; - + util::optional wildcard_key_; + // Schema associated with the input - std::shared_ptr _schema; - + std::shared_ptr schema_; + // Total number of batches (only int because InputFinished uses int) - int _total_batches = -1; - + int total_batches_ = -1; + // Number of batches processed so far (only int because InputFinished uses int) - int _batches_processed = 0; - + int batches_processed_ = 0; + // Index of the time col - col_index_t _time_col_index; - + col_index_t time_col_index_; + // Index of the key col - col_index_t _key_col_index; + col_index_t key_col_index_; + + // Index of the latest row reference within; if >0 then queue_ cannot be empty + row_index_t latest_ref_row_ = + 0; // must be < queue_.front()->num_rows() if queue_ is non-empty - // Index of the latest row reference within; if >0 then _queue cannot be empty - row_index_t _latest_ref_row = - 0; // must be < _queue.front()->num_rows() if _queue is non-empty - // Stores latest known values for the various keys - MemoStore _memo; - + MemoStore memo_; + // Mapping of source columns to destination columns - std::vector> src_to_dst; + std::vector> src_to_dst_; }; template @@ -318,11 +358,11 @@ class CompositeReferenceTable { // Adds a RecordBatch ref to the mapping, if needed void add_record_batch_ref(const std::shared_ptr& ref) { - if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; + if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t) ref.get()] = ref; } public: - CompositeReferenceTable(size_t n_tables) : _n_tables(n_tables) { + explicit CompositeReferenceTable(size_t n_tables) : _n_tables(n_tables) { assert(_n_tables >= 1); assert(_n_tables <= MAX_TABLES); } @@ -545,10 +585,12 @@ class AsofJoinNode : public ExecNode { } // Prune memo entries that have expired (to bound memory consumption) - if (!lhs.empty()) - for (size_t i = 1; i < _state.size(); ++i) + if (!lhs.empty()) { + for (size_t i = 1; i < _state.size(); ++i) { 
_state[i]->remove_memo_entries_with_lesser_time(lhs.get_latest_time() - _options._tolerance); + } + } // Emit the batch std::shared_ptr r = @@ -582,40 +624,13 @@ class AsofJoinNode : public ExecNode { // It may happen here in cases where InputFinished was called before we were finished // producing results (so we didn't know the output size at that time) if (_state.at(0)->finished()) { - // cerr << "LHS is finished\n"; - _total_batches_produced = util::make_optional(_progress_batches_produced); - std::cerr << "process() finished " << *_total_batches_produced << "\n"; + total_batches_produced_ = util::make_optional(_progress_batches_produced); StopProducing(); - assert(_total_batches_produced.has_value()); - outputs_[0]->InputFinished(this, *_total_batches_produced); + assert(total_batches_produced_.has_value()); + outputs_[0]->InputFinished(this, *total_batches_produced_); } } - // Status process_thread(size_t /*thread_index*/, int64_t /*task_id*/) { - // std::cerr << "AsOfJoinNode::process_thread started.\n"; - // auto result = _process.try_pop(); - - // if (result == util::nullopt) { - // std::cerr << "AsOfJoinNode::process_thread no inputs.\n"; - // return Status::OK(); - // } else { - // if (result.value()) { - // std::cerr << "AsOfJoinNode::process_thread process.\n"; - // process(); - // } else { - // std::cerr << "AsOfJoinNode::process_thread done.\n"; - // return Status::OK(); - // } - // } - - // return Status::OK(); - // } - - // Status process_finished(size_t /*thread_index*/) { - // std::cerr << "AsOfJoinNode::process_finished started.\n"; - // return Status::OK(); - // } - void process_thread() { std::cerr << "AsofJoinNode::process_thread started.\n"; for (;;) { @@ -630,9 +645,8 @@ class AsofJoinNode : public ExecNode { static void process_thread_wrapper(AsofJoinNode* node) { node->process_thread(); } public: - AsofJoinNode(ExecPlan* plan, NodeVector inputs, - std::vector input_labels, - const AsofJoinNodeOptions& join_options, + AsofJoinNode(ExecPlan* 
plan, NodeVector inputs, std::vector input_labels, + const AsofJoinNodeOptions& join_options, std::shared_ptr output_schema, std::unique_ptr schema_mgr, std::unique_ptr impl); @@ -652,17 +666,14 @@ class AsofJoinNode : public ExecNode { schema_mgr->MakeOutputSchema(inputs, join_options); ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, AsofJoinImpl::MakeBasic()); - std::cerr << "Input size: " << inputs.size() << "\n"; std::vector input_labels(inputs.size()); input_labels[0] = "left"; - for(size_t i = 1; i < inputs.size(); ++i) { + for (size_t i = 1; i < inputs.size(); ++i) { input_labels[i] = "right_" + std::to_string(i); - } - - return plan->EmplaceNode(plan, inputs, - std::move(input_labels), - join_options, - std::move(output_schema), + } + + return plan->EmplaceNode(plan, inputs, std::move(input_labels), + join_options, std::move(output_schema), std::move(schema_mgr), std::move(impl)); } @@ -696,7 +707,6 @@ class AsofJoinNode : public ExecNode { size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); // cerr << "set_total_batches for input " << k << ": " << total_batches << "\n"; _state.at(k)->set_total_batches(total_batches); - // is_finished=_state.at(k)->finished(); } // Trigger a process call // The reason for this is that there are cases at the end of a table where we don't @@ -710,24 +720,6 @@ class AsofJoinNode : public ExecNode { std::cout << "StartProducing" << "\n"; finished_ = arrow::Future<>::Make(); - // bool use_sync_execution = !(plan_->exec_context()->executor()); - // std::cerr << "StartScheduling\n"; - // std::cerr << "use_sync_execution: " << use_sync_execution << std::endl; - // RETURN_NOT_OK( - // scheduler_->StartScheduling(0 /*thread index*/, - // std::move([this](std::function - // func) -> Status { - // return - // this->ScheduleTaskCallback(std::move(func)); - // }), - // 1, - // use_sync_execution - // ) - // ); - // RETURN_NOT_OK( - // scheduler_->StartTaskGroup(0, task_group_process_, 1) - // ); - // std::cerr << 
"StartScheduling done\n"; return Status::OK(); } void PauseProducing(ExecNode* output) override { @@ -752,48 +744,25 @@ class AsofJoinNode : public ExecNode { } Future<> finished() override { return finished_; } - // Status ScheduleTaskCallback(std::function func) { - // auto executor = plan_->exec_context()->executor(); - // if (executor) { - // RETURN_NOT_OK(executor->Spawn([this, func] { - // size_t thread_index = thread_indexer_(); - // Status status = func(thread_index); - // if (!status.ok()) { - // StopProducing(); - // ErrorIfNotOk(status); - // return; - // } - // })); - // } else { - // // We should not get here in serial execution mode - // ARROW_DCHECK(false); - // } - // return Status::OK(); - // } - private: std::unique_ptr schema_mgr_; std::unique_ptr impl_; Future<> finished_; // InputStates // Each input state correponds to an input table - // + // std::vector> _state; std::mutex _gate; AsofJoinNodeOptions _options; - // ThreadIndexer thread_indexer_; - // std::unique_ptr scheduler_; - // int task_group_process_; - // Queue for triggering processing of a given input // (a false value is a poison pill) - concurrent_bounded_queue _process; + concurrent_queue _process; // Worker thread std::thread _process_thread; // Total batches produced, once we've finished -- only known at completion time. 
- util::optional _total_batches_produced; + util::optional total_batches_produced_; // In-progress batches produced int _progress_batches_produced = 0; @@ -804,12 +773,6 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema( std::vector> fields; assert(inputs.size() > 1); - // TODO: Deal with multi keys - // std::vector keys; - // for (auto f: options.keys) { - // keys.emplace_back(*f.name()); - // } - // Directly map LHS fields for (int i = 0; i < inputs[0]->output_schema()->num_fields(); ++i) fields.push_back(inputs[0]->output_schema()->field(i)); @@ -825,12 +788,11 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema( } } - // Combine into a schema return std::make_shared(fields); } AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, - std::vector input_labels, + std::vector input_labels, const AsofJoinNodeOptions& join_options, std::shared_ptr output_schema, std::unique_ptr schema_mgr, @@ -840,11 +802,8 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, /*num_outputs=*/1), impl_(std::move(impl)), _options(join_options), - _process(1), + _process(), _process_thread(&AsofJoinNode::process_thread_wrapper, this) { - std::cout << "AsofJoinNode created" - << "\n"; - for (size_t i = 0; i < inputs.size(); ++i) _state.push_back(::arrow::internal::make_unique( inputs[i]->output_schema(), *_options.time.name(), *_options.keys.name(), @@ -854,22 +813,6 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, dst_offset = state->init_src_to_dst_mapping(dst_offset, !!dst_offset); finished_ = Future<>::MakeFinished(); - - // scheduler_ = TaskScheduler::Make(); - // task_group_process_ = scheduler_->RegisterTaskGroup( - // [this](size_t thread_index, - // int64_t task_id) -> Status { - // return - // process_thread(thread_index, - // task_id); - // }, - // [this](size_t thread_index) -> - // Status { - // return - // process_finished(thread_index); - // } - // ); - // scheduler_->RegisterEnd(); } namespace internal { diff --git 
a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index bc7e7b0b633..ee6a6da9567 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -65,9 +65,9 @@ BatchesWithSchema GenerateBatchesFromString( void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, - const BatchesWithSchema& r1_batches, + const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, const FieldRef time, - const FieldRef keys, const long tolerance) { + const FieldRef keys, const int64_t tolerance) { auto exec_ctx = arrow::internal::make_unique(default_memory_pool(), nullptr); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); @@ -97,7 +97,8 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, AssertTablesEqual(*exp_table, *res_table, /*same_chunk_layout=*/false, /*flatten=*/true); - std::cerr << "Result Equals" << "\n"; + std::cerr << "Result Equals" + << "\n"; } void RunNonEmptyTest(bool exact_matches) { @@ -120,103 +121,86 @@ void RunNonEmptyTest(bool exact_matches) { BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; // Single key, single batch - l_batches = GenerateBatchesFromString(l_schema, - {R"([[0, 1, 1.0], [1000, 1, 2.0]])"} - ); + l_batches = GenerateBatchesFromString(l_schema, {R"([[0, 1, 1.0], [1000, 1, 2.0]])"}); r0_batches = GenerateBatchesFromString(r0_schema, {R"([[0, 1, 11.0]])"}); r1_batches = GenerateBatchesFromString(r1_schema, {R"([[1000, 1, 101.0]])"}); - exp_batches = GenerateBatchesFromString(exp_schema, - {R"([[0, 1, 1.0, 11.0, 0.0], [1000, 1, 2.0, 11.0, 101.0]])"} - ); + exp_batches = GenerateBatchesFromString( + exp_schema, {R"([[0, 1, 1.0, 11.0, 0.0], [1000, 1, 2.0, 11.0, 101.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); // Single key, multiple batches - l_batches = GenerateBatchesFromString(l_schema, - {R"([[0, 1, 
1.0]])", R"([[1000, 1, 2.0]])"} - ); - r0_batches = GenerateBatchesFromString(r0_schema, - {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"} - ); + l_batches = + GenerateBatchesFromString(l_schema, {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}); + r0_batches = + GenerateBatchesFromString(r0_schema, {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"}); r1_batches = GenerateBatchesFromString(r1_schema, - {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"} - ); - exp_batches = GenerateBatchesFromString(exp_schema, - {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"} - ); + {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"}); + exp_batches = GenerateBatchesFromString( + exp_schema, {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); // Single key, multiple left batches, single right batches - l_batches = GenerateBatchesFromString(l_schema, - {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"} - ); + l_batches = + GenerateBatchesFromString(l_schema, {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}); - r0_batches = GenerateBatchesFromString(r0_schema, - {R"([[0, 1, 11.0], [1000, 1, 12.0]])"} - ); - r1_batches = GenerateBatchesFromString(r1_schema, - {R"([[0, 1, 101.0], [1000, 1, 102.0]])"} - ); - exp_batches = GenerateBatchesFromString(exp_schema, - {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"} - ); + r0_batches = + GenerateBatchesFromString(r0_schema, {R"([[0, 1, 11.0], [1000, 1, 12.0]])"}); + r1_batches = + GenerateBatchesFromString(r1_schema, {R"([[0, 1, 101.0], [1000, 1, 102.0]])"}); + exp_batches = GenerateBatchesFromString( + exp_schema, {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); // Multi key, multiple batches, misaligned batches - l_batches = GenerateBatchesFromString(l_schema, - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 
2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"} - ); - r0_batches = GenerateBatchesFromString(r0_schema, - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])" - }); - r1_batches = GenerateBatchesFromString(r0_schema, - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])" - }); - exp_batches = GenerateBatchesFromString(exp_schema, - {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"} - ); + l_batches = GenerateBatchesFromString( + l_schema, + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}); + r0_batches = GenerateBatchesFromString( + r0_schema, {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}); + r1_batches = GenerateBatchesFromString( + r0_schema, {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); + exp_batches = GenerateBatchesFromString( + exp_schema, + {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); // Multi key, multiple batches, misaligned batches, smaller tolerance - l_batches = GenerateBatchesFromString(l_schema, - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"} - ); - r0_batches = GenerateBatchesFromString(r0_schema, - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 
12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])" - }); - r1_batches = GenerateBatchesFromString(r0_schema, - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])" - }); - exp_batches = GenerateBatchesFromString(exp_schema, - {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 0.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"} - ); + l_batches = GenerateBatchesFromString( + l_schema, + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}); + r0_batches = GenerateBatchesFromString( + r0_schema, {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}); + r1_batches = GenerateBatchesFromString( + r0_schema, {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); + exp_batches = GenerateBatchesFromString( + exp_schema, + {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 0.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 500); - // Multi key, multiple batches, misaligned batches, 0 tolerance - l_batches = GenerateBatchesFromString(l_schema, - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"} - ); - r0_batches = GenerateBatchesFromString(r0_schema, - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])" - }); - r1_batches = GenerateBatchesFromString(r0_schema, - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", 
- R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])" - }); - exp_batches = GenerateBatchesFromString(exp_schema, - {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 0.0, 101.0], [1000, 2, 22.0, 0.0, 0.0], [1500, 1, 3.0, 0.0, 0.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 0.0, 0.0]])"} - ); + // Multi key, multiple batches, misaligned batches, zero tolerance + l_batches = GenerateBatchesFromString( + l_schema, + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}); + r0_batches = GenerateBatchesFromString( + r0_schema, {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}); + r1_batches = GenerateBatchesFromString( + r0_schema, {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); + exp_batches = GenerateBatchesFromString( + exp_schema, + {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 0.0, 101.0], [1000, 2, 22.0, 0.0, 0.0], [1500, 1, 3.0, 0.0, 0.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 0.0, 0.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 0); } diff --git a/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h b/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h deleted file mode 100644 index 5451ad08e1f..00000000000 --- a/cpp/src/arrow/compute/exec/concurrent_bounded_queue.h +++ /dev/null @@ -1,87 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include -#include -#include - -#include - -namespace arrow { -namespace compute { - -template -class concurrent_bounded_queue { - size_t _remaining; - std::vector _buffer; - mutable std::mutex _gate; - std::condition_variable _not_full; - std::condition_variable _not_empty; - - size_t _next_push = 0; - size_t _next_pop = 0; - - public: - concurrent_bounded_queue(size_t capacity) : _remaining(capacity), _buffer(capacity) {} - // Push new value to queue, waiting for capacity indefinitely. - void push(const T& t) { - std::unique_lock lock(_gate); - _not_full.wait(lock, [&] { return _remaining > 0; }); - _buffer[_next_push++] = t; - _next_push %= _buffer.size(); - --_remaining; - _not_empty.notify_one(); - } - // Get oldest value from queue, or wait indefinitely for it. 
- T pop() { - std::unique_lock lock(_gate); - _not_empty.wait(lock, [&] { return _remaining < _buffer.size(); }); - T r = _buffer[_next_pop++]; - _next_pop %= _buffer.size(); - ++_remaining; - _not_full.notify_one(); - return r; - } - // Try to pop the oldest value from the queue (or return nullopt if none) - util::optional try_pop() { - std::unique_lock lock(_gate); - if (_remaining == _buffer.size()) return util::nullopt; - T r = _buffer[_next_pop++]; - _next_pop %= _buffer.size(); - ++_remaining; - _not_full.notify_one(); - return r; - } - - // Test whether empty - bool empty() const { - std::unique_lock lock(_gate); - return _remaining == _buffer.size(); - } - - // Un-synchronized access to front - // For this to be "safe": - // 1) the caller logically guarantees that queue is not empty - // 2) pop/try_pop cannot be called concurrently with this - const T& unsync_front() const { return _buffer[_next_pop]; } -}; - -} // namespace compute -} // namespace arrow From 643e368ae27922cfb58eb00c4a070e007ecbd653 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 10 May 2022 16:28:15 -0400 Subject: [PATCH 09/47] Clean up some files --- cpp/src/arrow/CMakeLists.txt | 1 - cpp/src/arrow/compute/exec/asof_join.cc | 57 -------------------- cpp/src/arrow/compute/exec/asof_join.h | 14 ----- cpp/src/arrow/compute/exec/asof_join_node.cc | 19 +++---- cpp/src/arrow/compute/exec/exec_plan.cc | 2 - cpp/src/arrow/compute/exec/options.h | 6 +-- 6 files changed, 10 insertions(+), 89 deletions(-) delete mode 100644 cpp/src/arrow/compute/exec/asof_join.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 0a098df410f..4702a427bad 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -381,7 +381,6 @@ if(ARROW_COMPUTE) compute/cast.cc compute/exec.cc compute/exec/aggregate_node.cc - compute/exec/asof_join.cc compute/exec/asof_join_node.cc compute/exec/bloom_filter.cc compute/exec/exec_plan.cc diff --git 
a/cpp/src/arrow/compute/exec/asof_join.cc b/cpp/src/arrow/compute/exec/asof_join.cc deleted file mode 100644 index e3d541e0d4b..00000000000 --- a/cpp/src/arrow/compute/exec/asof_join.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#include "arrow/compute/exec/asof_join.h" -#include - -#include -#include -#include -#include -#include -//#include -#include -#include -#include -#include -#include -#include -#include -#include // so we don't need to require C++20 -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace arrow { -namespace compute { - -class AsofJoinBasicImpl : public AsofJoinImpl {}; - -Result> AsofJoinImpl::MakeBasic() { - std::unique_ptr impl{new AsofJoinBasicImpl()}; - return std::move(impl); -} - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h index 4ed71c8f76e..54b1f665c78 100644 --- a/cpp/src/arrow/compute/exec/asof_join.h +++ b/cpp/src/arrow/compute/exec/asof_join.h @@ -42,15 +42,6 @@ typedef int32_t KeyType; // Maximum number of tables that can be joined #define MAX_JOIN_TABLES 64 - -// Capacity of the input queues (for flow control) -// Why 2? -// It needs to be at least 1 to enable progress (otherwise queues have no capacity) -// It needs to be at least 2 to enable addition of a new queue entry while processing -// is being done for another input. -// There's no clear performance benefit to greater than 2. 
-#define QUEUE_CAPACITY 2 - // The max rows per batch is dictated by the data type for row index #define MAX_ROWS_PER_BATCH 0xFFFFFFFF typedef uint32_t row_index_t; @@ -62,10 +53,5 @@ class AsofJoinSchema { const AsofJoinNodeOptions& options); }; -class AsofJoinImpl { - public: - static Result> MakeBasic(); -}; - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 07a15baec60..8cad9b33ad6 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -60,7 +60,7 @@ class concurrent_queue { public: T pop() { std::unique_lock lock(mutex_); - cond_.wait(lock, [&] { return !queue_.empty();}); + cond_.wait(lock, [&] { return !queue_.empty(); }); auto item = queue_.front(); queue_.pop(); return item; @@ -358,7 +358,7 @@ class CompositeReferenceTable { // Adds a RecordBatch ref to the mapping, if needed void add_record_batch_ref(const std::shared_ptr& ref) { - if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t) ref.get()] = ref; + if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } public: @@ -577,7 +577,7 @@ class AsofJoinNode : public ExecNode { // Only update if we have up-to-date information for the LHS row if (is_up_to_date_for_lhs_row()) { - dst.emplace(_state, _options._tolerance); + dst.emplace(_state, _options.tolerance); if (!lhs.advance()) break; // if we can't advance LHS, we're done for this batch } else { if ((!any_advanced) && (_state.size() > 1)) break; // need to wait for new data @@ -588,7 +588,7 @@ class AsofJoinNode : public ExecNode { if (!lhs.empty()) { for (size_t i = 1; i < _state.size(); ++i) { _state[i]->remove_memo_entries_with_lesser_time(lhs.get_latest_time() - - _options._tolerance); + _options.tolerance); } } @@ -648,8 +648,7 @@ class AsofJoinNode : public ExecNode { AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const 
AsofJoinNodeOptions& join_options, std::shared_ptr output_schema, - std::unique_ptr schema_mgr, - std::unique_ptr impl); + std::unique_ptr schema_mgr); virtual ~AsofJoinNode() { _process.push(false); // poison pill @@ -664,7 +663,6 @@ class AsofJoinNode : public ExecNode { const auto& join_options = checked_cast(options); std::shared_ptr output_schema = schema_mgr->MakeOutputSchema(inputs, join_options); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, AsofJoinImpl::MakeBasic()); std::vector input_labels(inputs.size()); input_labels[0] = "left"; @@ -674,7 +672,7 @@ class AsofJoinNode : public ExecNode { return plan->EmplaceNode(plan, inputs, std::move(input_labels), join_options, std::move(output_schema), - std::move(schema_mgr), std::move(impl)); + std::move(schema_mgr)); } const char* kind_name() const override { return "AsofJoinNode"; } @@ -746,7 +744,6 @@ class AsofJoinNode : public ExecNode { private: std::unique_ptr schema_mgr_; - std::unique_ptr impl_; Future<> finished_; // InputStates // Each input state correponds to an input table @@ -795,12 +792,10 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const AsofJoinNodeOptions& join_options, std::shared_ptr output_schema, - std::unique_ptr schema_mgr, - std::unique_ptr impl) + std::unique_ptr schema_mgr) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), - impl_(std::move(impl)), _options(join_options), _process(), _process_thread(&AsofJoinNode::process_thread_wrapper, this) { diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 33f98f69a90..95e8953065e 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -17,7 +17,6 @@ #include "arrow/compute/exec/exec_plan.h" -#include #include #include #include @@ -76,7 +75,6 @@ struct ExecPlanImpl : public ExecPlan { return Status::Invalid("ExecPlan has no node"); } for (const auto& 
node : nodes_) { - std::cerr << node->kind_name() << "\n"; RETURN_NOT_OK(node->Validate()); } return Status::OK(); diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 9817b9b2180..8283fff353d 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -364,14 +364,14 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: AsofJoinNodeOptions(FieldRef time, FieldRef keys, int64_t tolerance) - : time(std::move(time)), keys(std::move(keys)), _tolerance(tolerance) {} + : time(std::move(time)), keys(std::move(keys)), tolerance(tolerance) {} // time column FieldRef time; // keys used for the join. All tables must have the same join key. FieldRef keys; - - int64_t _tolerance; + // tolerance for the inexact timestamp matching in nanoseconds + int64_t tolerance; }; /// \brief Make a node which select top_k/bottom_k rows passed through it From 0781a16ac71f1e17a15a58fc11eb8166f9de7098 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 11 May 2022 12:15:31 -0400 Subject: [PATCH 10/47] Clean up some files --- cpp/src/arrow/compute/exec/asof_join_node.cc | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 8cad9b33ad6..ffb1157f658 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -347,11 +347,11 @@ struct CompositeReferenceRow { // We don't put the shared_ptr's into the rows for efficiency reasons. 
template class CompositeReferenceTable { - // Contains shared_ptr refs for all RecordBatches referred to by the contents of _rows + // Contains shared_ptr refs for all RecordBatches referred to by the contents of rows_ std::unordered_map> _ptr2ref; // Row table references - std::vector> _rows; + std::vector> rows_; // Total number of tables in the composite table size_t _n_tables; @@ -367,7 +367,7 @@ class CompositeReferenceTable { assert(_n_tables <= MAX_TABLES); } - size_t n_rows() const { return _rows.size(); } + size_t n_rows() const { return rows_.size(); } // Adds the latest row from the input state as a new composite reference row // - LHS must have a valid key,timestep,and latest rows @@ -391,13 +391,13 @@ class CompositeReferenceTable { assert(lhs_latest_batch->num_rows() <= MAX_ROWS_PER_BATCH); // TODO: better error handling row_index_t new_batch_size = lhs_latest_batch->num_rows(); - row_index_t new_capacity = _rows.size() + new_batch_size; - if (_rows.capacity() < new_capacity) _rows.reserve(new_capacity); - // cerr << "new_batch_size=" << new_batch_size << " old_size=" << _rows.size() << " - // new_capacity=" << _rows.capacity() << endl; + row_index_t new_capacity = rows_.size() + new_batch_size; + if (rows_.capacity() < new_capacity) rows_.reserve(new_capacity); + // cerr << "new_batch_size=" << new_batch_size << " old_size=" << rows_.size() << " + // new_capacity=" << rows_.capacity() << endl; } - _rows.resize(_rows.size() + 1); - auto& row = _rows.back(); + rows_.resize(rows_.size() + 1); + auto& row = rows_.back(); row._refs[0]._batch = lhs_latest_batch.get(); row._refs[0]._row = lhs_latest_row; add_record_batch_ref(lhs_latest_batch); @@ -428,9 +428,9 @@ class CompositeReferenceTable { std::shared_ptr materialize_primitive_column(size_t i_table, size_t i_col) { Builder builder; // builder.Resize(_rows.size()); // <-- can't just do this -- need to set the bitmask - builder.AppendEmptyValues(_rows.size()); - for (row_index_t i_row = 0; i_row < 
_rows.size(); ++i_row) { - const auto& ref = _rows[i_row]._refs[i_table]; + builder.AppendEmptyValues(rows_.size()); + for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { + const auto& ref = rows_[i_row]._refs[i_table]; if (ref._batch) builder[i_row] = ref._batch->column_data(i_col)->template GetValues( @@ -456,7 +456,7 @@ class CompositeReferenceTable { assert(state.size() >= 1); // Don't build empty batches - size_t n_rows = _rows.size(); + size_t n_rows = rows_.size(); if (!n_rows) return nullptr; // Count output columns (dbg sanitycheck) @@ -518,7 +518,7 @@ class CompositeReferenceTable { } // Returns true if there are no rows - bool empty() const { return _rows.empty(); } + bool empty() const { return rows_.empty(); } }; class AsofJoinNode : public ExecNode { From fc75844c45231a6269b11566852471ed6da1aeec Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 11 May 2022 16:40:56 -0400 Subject: [PATCH 11/47] Minor clean up --- cpp/src/arrow/compute/exec/asof_join_node.cc | 116 +++++++++---------- 1 file changed, 57 insertions(+), 59 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index ffb1157f658..f9c558d1067 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -329,10 +329,10 @@ class InputState { template struct CompositeReferenceRow { struct Entry { - arrow::RecordBatch* _batch; // can be NULL if there's no value - row_index_t _row; + arrow::RecordBatch* batch; // can be NULL if there's no value + row_index_t row; }; - Entry _refs[MAX_TABLES]; + Entry refs[MAX_TABLES]; }; // A table of composite reference rows. Rows maintain pointers to the @@ -347,24 +347,10 @@ struct CompositeReferenceRow { // We don't put the shared_ptr's into the rows for efficiency reasons. 
template class CompositeReferenceTable { - // Contains shared_ptr refs for all RecordBatches referred to by the contents of rows_ - std::unordered_map> _ptr2ref; - - // Row table references - std::vector> rows_; - - // Total number of tables in the composite table - size_t _n_tables; - - // Adds a RecordBatch ref to the mapping, if needed - void add_record_batch_ref(const std::shared_ptr& ref) { - if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; - } - public: - explicit CompositeReferenceTable(size_t n_tables) : _n_tables(n_tables) { - assert(_n_tables >= 1); - assert(_n_tables <= MAX_TABLES); + explicit CompositeReferenceTable(size_t n_tables) : n_tables_(n_tables) { + assert(n_tables_ >= 1); + assert(n_tables_ <= MAX_TABLES); } size_t n_rows() const { return rows_.size(); } @@ -373,7 +359,7 @@ class CompositeReferenceTable { // - LHS must have a valid key,timestep,and latest rows // - RHS must have valid data memo'ed for the key void emplace(std::vector>& in, int64_t tolerance) { - assert(in.size() == _n_tables); + assert(in.size() == n_tables_); // Get the LHS key KeyType key = in[0]->get_latest_key(); @@ -393,13 +379,11 @@ class CompositeReferenceTable { row_index_t new_batch_size = lhs_latest_batch->num_rows(); row_index_t new_capacity = rows_.size() + new_batch_size; if (rows_.capacity() < new_capacity) rows_.reserve(new_capacity); - // cerr << "new_batch_size=" << new_batch_size << " old_size=" << rows_.size() << " - // new_capacity=" << rows_.capacity() << endl; } rows_.resize(rows_.size() + 1); auto& row = rows_.back(); - row._refs[0]._batch = lhs_latest_batch.get(); - row._refs[0]._row = lhs_latest_row; + row.refs[0].batch = lhs_latest_batch.get(); + row.refs[0].row = lhs_latest_row; add_record_batch_ref(lhs_latest_batch); // Get the state for that key from all on the RHS -- assumes it's up to date @@ -412,52 +396,28 @@ class CompositeReferenceTable { if ((*opt_entry)->_time + tolerance >= lhs_latest_time) { // Have a valid 
entry const MemoStore::Entry* entry = *opt_entry; - row._refs[i]._batch = entry->_batch.get(); - row._refs[i]._row = entry->_row; + row.refs[i].batch = entry->_batch.get(); + row.refs[i].row = entry->_row; add_record_batch_ref(entry->_batch); continue; } } - row._refs[i]._batch = NULL; - row._refs[i]._row = 0; + row.refs[i].batch = NULL; + row.refs[i].row = 0; } } - private: - template - std::shared_ptr materialize_primitive_column(size_t i_table, size_t i_col) { - Builder builder; - // builder.Resize(_rows.size()); // <-- can't just do this -- need to set the bitmask - builder.AppendEmptyValues(rows_.size()); - for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { - const auto& ref = rows_[i_row]._refs[i_table]; - if (ref._batch) - builder[i_row] = - ref._batch->column_data(i_col)->template GetValues( - 1)[ref._row]; - // TODO: set null value if ref._batch is null -- currently we don't due to API - // limitations of the builders. - } - std::shared_ptr result; - if (!builder.Finish(&result).ok()) { - std::cerr << "Error when creating Arrow array from builder\n"; - exit(-1); // TODO: better error handling - } - return result; - } - - public: // Materializes the current reference table into a target record batch std::shared_ptr materialize( const std::shared_ptr& output_schema, const std::vector>& state) { // cerr << "materialize BEGIN\n"; - assert(state.size() == _n_tables); + assert(state.size() == n_tables_); assert(state.size() >= 1); // Don't build empty batches size_t n_rows = rows_.size(); - if (!n_rows) return nullptr; + if (!n_rows) return NULLPTR; // Count output columns (dbg sanitycheck) { @@ -473,9 +433,9 @@ class CompositeReferenceTable { std::shared_ptr i64_type = arrow::int64(); std::shared_ptr f64_type = arrow::float64(); - // Build the arrays column-by-column from our rows + // Build the arrays column-by-column from the rows std::vector> arrays(output_schema->num_fields()); - for (size_t i_table = 0; i_table < _n_tables; ++i_table) { + for 
(size_t i_table = 0; i_table < n_tables_; ++i_table) { int n_src_cols = state.at(i_table)->get_schema()->num_fields(); { for (col_index_t i_src_col = 0; i_src_col < n_src_cols; ++i_src_col) { @@ -502,7 +462,7 @@ class CompositeReferenceTable { i_src_col); } else { std::cerr << "Unsupported data type: " << field_type->name() << "\n"; - exit(-1); // TODO: validate elsewhere for better error handling + exit(-1); // TODO: validate elsewhere for better error handling } } } @@ -519,6 +479,44 @@ class CompositeReferenceTable { // Returns true if there are no rows bool empty() const { return rows_.empty(); } + + private: + // Contains shared_ptr refs for all RecordBatches referred to by the contents of rows_ + std::unordered_map> _ptr2ref; + + // Row table references + std::vector> rows_; + + // Total number of tables in the composite table + size_t n_tables_; + + // Adds a RecordBatch ref to the mapping, if needed + void add_record_batch_ref(const std::shared_ptr& ref) { + if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; + } + + template + std::shared_ptr materialize_primitive_column(size_t i_table, size_t i_col) { + Builder builder; + // builder.Resize(_rows.size()); // <-- can't just do this -- need to set the bitmask + builder.AppendEmptyValues(rows_.size()); + for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { + const auto& ref = rows_[i_row].refs[i_table]; + if (ref.batch) { + builder[i_row] = + ref.batch->column_data(i_col)->template GetValues( + 1)[ref.row]; + } + // TODO: set null value if ref.batch is null -- currently we don't due to API + // limitations of the builders. + } + std::shared_ptr result; + if (!builder.Finish(&result).ok()) { + std::cerr << "Error when creating Arrow array from builder\n"; + exit(-1); // TODO: better error handling + } + return result; + } }; class AsofJoinNode : public ExecNode { @@ -594,7 +592,7 @@ class AsofJoinNode : public ExecNode { // Emit the batch std::shared_ptr r = - dst.empty() ? 
nullptr : dst.materialize(output_schema(), _state); + dst.empty() ? NULLPTR : dst.materialize(output_schema(), _state); return r; } From 26bc8623f6430ca1901a3f81a53c7f138294718b Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 11 May 2022 17:18:22 -0400 Subject: [PATCH 12/47] Fix nulls in test result --- cpp/src/arrow/compute/exec/asof_join_node.cc | 13 +++++++------ cpp/src/arrow/compute/exec/asof_join_node_test.cc | 12 ++++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index f9c558d1067..a3b7444e3f9 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -499,16 +499,17 @@ class CompositeReferenceTable { std::shared_ptr materialize_primitive_column(size_t i_table, size_t i_col) { Builder builder; // builder.Resize(_rows.size()); // <-- can't just do this -- need to set the bitmask - builder.AppendEmptyValues(rows_.size()); + // builder.AppendEmptyValues(rows_.size()); + builder.Reserve(rows_.size()); for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; if (ref.batch) { - builder[i_row] = - ref.batch->column_data(i_col)->template GetValues( - 1)[ref.row]; + builder.UnsafeAppend( + ref.batch->column_data(i_col)->template GetValues(1)[ref.row] + ); + } else { + builder.AppendNull(); } - // TODO: set null value if ref.batch is null -- currently we don't due to API - // limitations of the builders. 
} std::shared_ptr result; if (!builder.Finish(&result).ok()) { diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index ee6a6da9567..e2d0b1c83ba 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -125,7 +125,7 @@ void RunNonEmptyTest(bool exact_matches) { r0_batches = GenerateBatchesFromString(r0_schema, {R"([[0, 1, 11.0]])"}); r1_batches = GenerateBatchesFromString(r1_schema, {R"([[1000, 1, 101.0]])"}); exp_batches = GenerateBatchesFromString( - exp_schema, {R"([[0, 1, 1.0, 11.0, 0.0], [1000, 1, 2.0, 11.0, 101.0]])"}); + exp_schema, {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); // Single key, multiple batches @@ -165,7 +165,7 @@ void RunNonEmptyTest(bool exact_matches) { R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); exp_batches = GenerateBatchesFromString( exp_schema, - {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); @@ -182,11 +182,11 @@ void RunNonEmptyTest(bool exact_matches) { R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); exp_batches = GenerateBatchesFromString( exp_schema, - {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 0.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], 
[1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 500); - // Multi key, multiple batches, misaligned batches, zero tolerance + // // Multi key, multiple batches, misaligned batches, zero tolerance l_batches = GenerateBatchesFromString( l_schema, {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", @@ -199,8 +199,8 @@ void RunNonEmptyTest(bool exact_matches) { R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); exp_batches = GenerateBatchesFromString( exp_schema, - {R"([[0, 1, 1.0, 11.0, 0.0], [0, 2, 21.0, 0.0, 1001.0], [500, 1, 2.0, 0.0, 101.0], [1000, 2, 22.0, 0.0, 0.0], [1500, 1, 3.0, 0.0, 0.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 0.0, 0.0]])"}); + {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null, null], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null, null]])"}); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 0); } From 5a6afbd5082b5f8175c6137d084bd7310ea80553 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 11 May 2022 17:43:16 -0400 Subject: [PATCH 13/47] Clean up includes --- cpp/src/arrow/compute/exec/asof_join.h | 15 -------- cpp/src/arrow/compute/exec/asof_join_node.cc | 38 ++++++-------------- 2 files changed, 11 insertions(+), 42 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h index 54b1f665c78..f90930bbee2 100644 --- a/cpp/src/arrow/compute/exec/asof_join.h +++ b/cpp/src/arrow/compute/exec/asof_join.h @@ -17,23 +17,10 @@ #pragma once -#include -#include -#include - #include #include -#include -#include -#include // so we don't need to require C++20 
-#include #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" -#include "arrow/compute/exec/task_util.h" -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/type.h" -#include "arrow/util/tracing_internal.h" namespace arrow { namespace compute { @@ -42,8 +29,6 @@ typedef int32_t KeyType; // Maximum number of tables that can be joined #define MAX_JOIN_TABLES 64 -// The max rows per batch is dictated by the data type for row index -#define MAX_ROWS_PER_BATCH 0xFFFFFFFF typedef uint32_t row_index_t; typedef int col_index_t; diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index a3b7444e3f9..03d6f14f269 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -16,8 +16,10 @@ // under the License. #include -#include +#include +#include +#include #include "arrow/compute/exec/asof_join.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" @@ -25,29 +27,13 @@ #include "arrow/compute/exec/util.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" +#include #include "arrow/util/make_unique.h" -#include "arrow/util/thread_pool.h" -#include -#include -#include -#include -#include -//#include -#include -#include -#include -#include -#include // so we don't need to require C++20 -#include -#include -#include -#include #include -#include +#include #include -#include namespace arrow { namespace compute { @@ -56,7 +42,7 @@ namespace compute { * Simple implementation for an unbound concurrent queue */ template -class concurrent_queue { +class ConcurrentQueue { public: T pop() { std::unique_lock lock(mutex_); @@ -295,7 +281,7 @@ class InputState { private: // Pending record batches. The latest is the front. Batches cannot be empty. - concurrent_queue> queue_; + ConcurrentQueue> queue_; // Wildcard key for this input, if applicable. 
util::optional wildcard_key_; @@ -374,8 +360,6 @@ class CompositeReferenceTable { if (0 == lhs_latest_row) { // On the first row of the batch, we resize the destination. // The destination size is dictated by the size of the LHS batch. - assert(lhs_latest_batch->num_rows() <= - MAX_ROWS_PER_BATCH); // TODO: better error handling row_index_t new_batch_size = lhs_latest_batch->num_rows(); row_index_t new_capacity = rows_.size() + new_batch_size; if (rows_.capacity() < new_capacity) rows_.reserve(new_capacity); @@ -739,11 +723,11 @@ class AsofJoinNode : public ExecNode { finished_.MarkFinished(); for (auto&& input : inputs_) input->StopProducing(this); } - Future<> finished() override { return finished_; } + arrow::Future<> finished() override { return finished_; } private: std::unique_ptr schema_mgr_; - Future<> finished_; + arrow::Future<> finished_; // InputStates // Each input state correponds to an input table // @@ -753,7 +737,7 @@ class AsofJoinNode : public ExecNode { // Queue for triggering processing of a given input // (a false value is a poison pill) - concurrent_queue _process; + ConcurrentQueue _process; // Worker thread std::thread _process_thread; @@ -806,7 +790,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, for (auto& state : _state) dst_offset = state->init_src_to_dst_mapping(dst_offset, !!dst_offset); - finished_ = Future<>::MakeFinished(); + finished_ = arrow::Future<>::MakeFinished(); } namespace internal { From 4f7cac72f73e23f69555d8974422637d59086579 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 12 May 2022 18:18:00 -0400 Subject: [PATCH 14/47] Clean up error handling --- cpp/src/arrow/compute/exec/asof_join_node.cc | 91 +++++++++++--------- 1 file changed, 48 insertions(+), 43 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 03d6f14f269..22cbc55c095 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ 
b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -20,21 +20,22 @@ #include #include +#include #include "arrow/compute/exec/asof_join.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/util.h" +#include "arrow/result.h" +#include "arrow/status.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" -#include #include "arrow/util/make_unique.h" -#include #include +#include #include - namespace arrow { namespace compute { @@ -392,7 +393,7 @@ class CompositeReferenceTable { } // Materializes the current reference table into a target record batch - std::shared_ptr materialize( + Result> materialize( const std::shared_ptr& output_schema, const std::vector>& state) { // cerr << "materialize BEGIN\n"; @@ -412,11 +413,6 @@ class CompositeReferenceTable { assert(n_out_cols == output_schema->num_fields()); } - // Instance the types we support - std::shared_ptr i32_type = arrow::int32(); - std::shared_ptr i64_type = arrow::int64(); - std::shared_ptr f64_type = arrow::float64(); - // Build the arrays column-by-column from the rows std::vector> arrays(output_schema->num_fields()); for (size_t i_table = 0; i_table < n_tables_; ++i_table) { @@ -432,21 +428,25 @@ class CompositeReferenceTable { assert(src_field->type()->Equals(dst_field->type())); assert(src_field->name() == dst_field->name()); const auto& field_type = src_field->type(); - if (field_type->Equals(i32_type)) { - arrays.at(i_dst_col) = - materialize_primitive_column(i_table, - i_src_col); - } else if (field_type->Equals(i64_type)) { - arrays.at(i_dst_col) = - materialize_primitive_column(i_table, - i_src_col); - } else if (field_type->Equals(f64_type)) { - arrays.at(i_dst_col) = - materialize_primitive_column(i_table, - i_src_col); + + if (field_type->Equals(arrow::int32())) { + ARROW_ASSIGN_OR_RAISE( + arrays.at(i_dst_col), + (materialize_primitive_column(i_table, + i_src_col))); + } else if 
(field_type->Equals(arrow::int64())) { + ARROW_ASSIGN_OR_RAISE( + arrays.at(i_dst_col), + (materialize_primitive_column(i_table, + i_src_col))); + } else if (field_type->Equals(arrow::float64())) { + ARROW_ASSIGN_OR_RAISE( + arrays.at(i_dst_col), + (materialize_primitive_column(i_table, + i_src_col))); } else { - std::cerr << "Unsupported data type: " << field_type->name() << "\n"; - exit(-1); // TODO: validate elsewhere for better error handling + ARROW_RETURN_NOT_OK( + Status::Invalid("Unsupported data type: ", src_field->name())); } } } @@ -454,10 +454,10 @@ class CompositeReferenceTable { // Build the result assert(sizeof(size_t) >= sizeof(int64_t)); // Make takes signed int64_t for num_rows + // TODO: check n_rows for cast std::shared_ptr r = arrow::RecordBatch::Make(output_schema, (int64_t)n_rows, arrays); - // cerr << "materialize END (ndstrows="<< (r?r->num_rows():-1) <<")\n"; return r; } @@ -480,26 +480,21 @@ class CompositeReferenceTable { } template - std::shared_ptr materialize_primitive_column(size_t i_table, size_t i_col) { + Result> materialize_primitive_column(size_t i_table, + size_t i_col) { Builder builder; - // builder.Resize(_rows.size()); // <-- can't just do this -- need to set the bitmask - // builder.AppendEmptyValues(rows_.size()); builder.Reserve(rows_.size()); for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; if (ref.batch) { builder.UnsafeAppend( - ref.batch->column_data(i_col)->template GetValues(1)[ref.row] - ); + ref.batch->column_data(i_col)->template GetValues(1)[ref.row]); } else { - builder.AppendNull(); + builder.AppendNull(); } } std::shared_ptr result; - if (!builder.Finish(&result).ok()) { - std::cerr << "Error when creating Arrow array from builder\n"; - exit(-1); // TODO: better error handling - } + ARROW_RETURN_NOT_OK(builder.Finish(&result)); return result; } }; @@ -542,7 +537,7 @@ class AsofJoinNode : public ExecNode { return true; } - std::shared_ptr 
process_inner() { + Result> process_inner() { assert(!_state.empty()); auto& lhs = *_state.at(0); @@ -576,9 +571,11 @@ class AsofJoinNode : public ExecNode { } // Emit the batch - std::shared_ptr r = - dst.empty() ? NULLPTR : dst.materialize(output_schema(), _state); - return r; + if (dst.empty()) { + return NULLPTR; + } else { + return dst.materialize(output_schema(), _state); + } } void process() { @@ -592,11 +589,19 @@ class AsofJoinNode : public ExecNode { // Process batches while we have data for (;;) { - std::shared_ptr out_rb = process_inner(); - if (!out_rb) break; - ++_progress_batches_produced; - ExecBatch out_b(*out_rb); - outputs_[0]->InputReceived(this, std::move(out_b)); + Result> result = process_inner(); + + if (result.ok()) { + auto out_rb = *result; + if (!out_rb) break; + ++_progress_batches_produced; + ExecBatch out_b(*out_rb); + outputs_[0]->InputReceived(this, std::move(out_b)); + } else { + // TODO: Proper error handling + outputs_[0]->ErrorReceived(this, result.status()); + // StopProducing(); + } } std::cerr << "process() end\n"; From 8773317c6540c5993dc337e962a7d48aa8d04a53 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 12 May 2022 18:29:36 -0400 Subject: [PATCH 15/47] Error handling --- cpp/src/arrow/compute/exec/asof_join_node.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 22cbc55c095..0512117a941 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -598,9 +598,9 @@ class AsofJoinNode : public ExecNode { ExecBatch out_b(*out_rb); outputs_[0]->InputReceived(this, std::move(out_b)); } else { - // TODO: Proper error handling - outputs_[0]->ErrorReceived(this, result.status()); - // StopProducing(); + StopProducing(); + ErrorIfNotOk(result.status()); + return; } } From 22c99414057a8f1ef924fb7880d7677952d60877 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 
17 May 2022 10:28:53 -0400 Subject: [PATCH 16/47] Fix compiler warning --- cpp/src/arrow/compute/exec/asof_join_node.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 0512117a941..2df7564eceb 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -708,11 +708,11 @@ class AsofJoinNode : public ExecNode { finished_ = arrow::Future<>::Make(); return Status::OK(); } - void PauseProducing(ExecNode* output) override { + void PauseProducing(ExecNode* output, int32_t counter) override { std::cout << "PauseProducing" << "\n"; } - void ResumeProducing(ExecNode* output) override { + void ResumeProducing(ExecNode* output, int32_t counter) override { std::cout << "ResumeProducing" << "\n"; } From 6b27e6b3237af039add6f5dde89b5cc8c68bb764 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 23 May 2022 14:33:58 -0400 Subject: [PATCH 17/47] Fix Wshorten-64-to-32 error --- cpp/src/arrow/compute/exec/asof_join.h | 2 +- cpp/src/arrow/compute/exec/asof_join_node.cc | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h index f90930bbee2..988c4984315 100644 --- a/cpp/src/arrow/compute/exec/asof_join.h +++ b/cpp/src/arrow/compute/exec/asof_join.h @@ -29,7 +29,7 @@ typedef int32_t KeyType; // Maximum number of tables that can be joined #define MAX_JOIN_TABLES 64 -typedef uint32_t row_index_t; +typedef uint64_t row_index_t; typedef int col_index_t; class AsofJoinSchema { diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 2df7564eceb..ce9509e61bb 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -152,7 +152,7 @@ class InputState { schema->GetFieldIndex(time_col_name)), // TODO: handle missing 
field name key_col_index_(schema->GetFieldIndex(key_col_name)) {} - size_t init_src_to_dst_mapping(size_t dst_offset, bool skip_time_and_key_fields) { + size_t init_src_to_dst_mapping(int dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); for (int i = 0; i < schema_->num_fields(); ++i) if (!(skip_time_and_key_fields && is_time_or_key_column(i))) @@ -206,7 +206,7 @@ class InputState { (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.empty(); if (have_active_batch) { // If we have an active batch - if (++latest_ref_row_ >= queue_.unsync_front()->num_rows()) { + if (++latest_ref_row_ >= (row_index_t) queue_.unsync_front()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. ++batches_processed_; latest_ref_row_ = 0; @@ -419,7 +419,7 @@ class CompositeReferenceTable { int n_src_cols = state.at(i_table)->get_schema()->num_fields(); { for (col_index_t i_src_col = 0; i_src_col < n_src_cols; ++i_src_col) { - util::optional i_dst_col_opt = + util::optional i_dst_col_opt = state[i_table]->map_src_to_dst(i_src_col); if (!i_dst_col_opt) continue; col_index_t i_dst_col = *i_dst_col_opt; @@ -481,7 +481,7 @@ class CompositeReferenceTable { template Result> materialize_primitive_column(size_t i_table, - size_t i_col) { + col_index_t i_col) { Builder builder; builder.Reserve(rows_.size()); for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { @@ -599,7 +599,7 @@ class AsofJoinNode : public ExecNode { outputs_[0]->InputReceived(this, std::move(out_b)); } else { StopProducing(); - ErrorIfNotOk(result.status()); + ErrorIfNotOk(result.status()); return; } } @@ -791,7 +791,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, _state.push_back(::arrow::internal::make_unique( inputs[i]->output_schema(), *_options.time.name(), *_options.keys.name(), util::make_optional(0) /*TODO: make wildcard configuirable*/)); - size_t dst_offset = 0; + int dst_offset = 0; for (auto& state : 
_state) dst_offset = state->init_src_to_dst_mapping(dst_offset, !!dst_offset); From 775be1d492a2611cd09a979df4f8a8446e7cb6ff Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 23 May 2022 14:58:11 -0400 Subject: [PATCH 18/47] Fix lint --- cpp/src/arrow/compute/exec/asof_join_node.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index ce9509e61bb..07571dde1d5 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -206,7 +206,7 @@ class InputState { (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.empty(); if (have_active_batch) { // If we have an active batch - if (++latest_ref_row_ >= (row_index_t) queue_.unsync_front()->num_rows()) { + if (++latest_ref_row_ >= (row_index_t)queue_.unsync_front()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. ++batches_processed_; latest_ref_row_ = 0; @@ -599,7 +599,7 @@ class AsofJoinNode : public ExecNode { outputs_[0]->InputReceived(this, std::move(out_b)); } else { StopProducing(); - ErrorIfNotOk(result.status()); + ErrorIfNotOk(result.status()); return; } } From 2dc5691f36fcd839e8e676a5a9df90c7a2ae37cc Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 23 May 2022 15:35:41 -0400 Subject: [PATCH 19/47] Fix lint --- cpp/src/arrow/compute/exec/asof_join_node.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 07571dde1d5..a22509ac4ba 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -152,7 +152,7 @@ class InputState { schema->GetFieldIndex(time_col_name)), // TODO: handle missing field name key_col_index_(schema->GetFieldIndex(key_col_name)) {} - size_t init_src_to_dst_mapping(int dst_offset, bool skip_time_and_key_fields) { 
+ col_index_t init_src_to_dst_mapping(col_index_t dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); for (int i = 0; i < schema_->num_fields(); ++i) if (!(skip_time_and_key_fields && is_time_or_key_column(i))) @@ -598,7 +598,7 @@ class AsofJoinNode : public ExecNode { ExecBatch out_b(*out_rb); outputs_[0]->InputReceived(this, std::move(out_b)); } else { - StopProducing(); + StopProducing(); ErrorIfNotOk(result.status()); return; } @@ -791,7 +791,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, _state.push_back(::arrow::internal::make_unique( inputs[i]->output_schema(), *_options.time.name(), *_options.keys.name(), util::make_optional(0) /*TODO: make wildcard configuirable*/)); - int dst_offset = 0; + col_index_t dst_offset = 0; for (auto& state : _state) dst_offset = state->init_src_to_dst_mapping(dst_offset, !!dst_offset); From a9dd980305e09514ba33b4a54a23f00df444ea55 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 24 May 2022 15:46:46 -0400 Subject: [PATCH 20/47] Fix compiler warning Wunused-result --- cpp/src/arrow/compute/exec/asof_join_node.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index a22509ac4ba..a8958c5465f 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -483,14 +483,14 @@ class CompositeReferenceTable { Result> materialize_primitive_column(size_t i_table, col_index_t i_col) { Builder builder; - builder.Reserve(rows_.size()); + ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size())); for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; if (ref.batch) { builder.UnsafeAppend( ref.batch->column_data(i_col)->template GetValues(1)[ref.row]); } else { - builder.AppendNull(); + builder.UnsafeAppendNull(); } } std::shared_ptr result; From 
0f39fce13b1c2c769c53ba1cba7f8f3107f0d0d7 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 24 May 2022 16:28:04 -0400 Subject: [PATCH 21/47] Fix format --- cpp/src/arrow/compute/exec/asof_join_node.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index a8958c5465f..d7f376eef6d 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -152,7 +152,8 @@ class InputState { schema->GetFieldIndex(time_col_name)), // TODO: handle missing field name key_col_index_(schema->GetFieldIndex(key_col_name)) {} - col_index_t init_src_to_dst_mapping(col_index_t dst_offset, bool skip_time_and_key_fields) { + col_index_t init_src_to_dst_mapping(col_index_t dst_offset, + bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); for (int i = 0; i < schema_->num_fields(); ++i) if (!(skip_time_and_key_fields && is_time_or_key_column(i))) @@ -598,9 +599,9 @@ class AsofJoinNode : public ExecNode { ExecBatch out_b(*out_rb); outputs_[0]->InputReceived(this, std::move(out_b)); } else { - StopProducing(); - ErrorIfNotOk(result.status()); - return; + StopProducing(); + ErrorIfNotOk(result.status()); + return; } } From 761e5dec84a1204d94a9c81a7e8a93e31410e872 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 24 May 2022 16:57:24 -0400 Subject: [PATCH 22/47] Remove debug statement --- cpp/src/arrow/compute/exec/asof_join_node.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index d7f376eef6d..9cdc0fb3acc 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -405,15 +405,6 @@ class CompositeReferenceTable { size_t n_rows = rows_.size(); if (!n_rows) return NULLPTR; - // Count output columns (dbg sanitycheck) - { - int n_out_cols = 0; - for (const 
auto& s : state) n_out_cols += s->get_schema()->num_fields(); - n_out_cols -= - (state.size() - 1) * 2; // remove column indices for key and time cols on RHS - assert(n_out_cols == output_schema->num_fields()); - } - // Build the arrays column-by-column from the rows std::vector> arrays(output_schema->num_fields()); for (size_t i_table = 0; i_table < n_tables_; ++i_table) { From 0387e5c1d5cfff28faf0066a44a128124c452d90 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 26 May 2022 09:48:22 -0400 Subject: [PATCH 23/47] Update cpp/src/arrow/compute/exec/asof_join_node.cc Co-authored-by: Weston Pace --- cpp/src/arrow/compute/exec/asof_join_node.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 9cdc0fb3acc..d5de8bd1235 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -398,8 +398,8 @@ class CompositeReferenceTable { const std::shared_ptr& output_schema, const std::vector>& state) { // cerr << "materialize BEGIN\n"; - assert(state.size() == n_tables_); - assert(state.size() >= 1); + DCHECK_EQ(state.size(), n_tables_); + DCHECK_GE(state.size(), 1); // Don't build empty batches size_t n_rows = rows_.size(); From 15ba43dfb0e52a08b2a06d297395c763cc12a342 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 26 May 2022 09:48:29 -0400 Subject: [PATCH 24/47] Update cpp/src/arrow/compute/exec/asof_join_node.cc Co-authored-by: Weston Pace --- cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index d5de8bd1235..131f121faa9 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -748,7 +748,7 @@ class AsofJoinNode : public ExecNode { std::shared_ptr AsofJoinSchema::MakeOutputSchema( const std::vector& 
inputs, const AsofJoinNodeOptions& options) { std::vector> fields; - assert(inputs.size() > 1); + DCHECK_GT(inputs.size(), 1); // Directly map LHS fields for (int i = 0; i < inputs[0]->output_schema()->num_fields(); ++i) From b92a3030e66f3d4a25690164952229e82765a2d6 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 26 May 2022 09:48:46 -0400 Subject: [PATCH 25/47] Update cpp/src/arrow/compute/exec/asof_join_node.cc Co-authored-by: Weston Pace --- cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 131f121faa9..83dd5c3455f 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -45,7 +45,7 @@ namespace compute { template class ConcurrentQueue { public: - T pop() { + T Pop() { std::unique_lock lock(mutex_); cond_.wait(lock, [&] { return !queue_.empty(); }); auto item = queue_.front(); From f0edd177b9d7e11374c6549f28031b441d1427f6 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 26 May 2022 09:48:53 -0400 Subject: [PATCH 26/47] Update cpp/src/arrow/compute/exec/asof_join.h Co-authored-by: Weston Pace --- cpp/src/arrow/compute/exec/asof_join.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h index 988c4984315..15f9e21c6b2 100644 --- a/cpp/src/arrow/compute/exec/asof_join.h +++ b/cpp/src/arrow/compute/exec/asof_join.h @@ -28,7 +28,7 @@ namespace compute { typedef int32_t KeyType; // Maximum number of tables that can be joined -#define MAX_JOIN_TABLES 64 +constexpr int kMaxJoinTables = 64 typedef uint64_t row_index_t; typedef int col_index_t; From 58f229d4e8ba19b270d836d2eca89a5dd54462ae Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 26 May 2022 10:08:00 -0400 Subject: [PATCH 27/47] Update cpp/src/arrow/compute/exec/asof_join_node.cc Co-authored-by: Weston Pace --- 
cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 83dd5c3455f..ef0fc55d24c 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -606,7 +606,7 @@ class AsofJoinNode : public ExecNode { if (_state.at(0)->finished()) { total_batches_produced_ = util::make_optional(_progress_batches_produced); StopProducing(); - assert(total_batches_produced_.has_value()); + DCHECK(total_batches_produced_.has_value()); outputs_[0]->InputFinished(this, *total_batches_produced_); } } From 15e27830e23f7dfa86a55922ea92e9f5255bf1c5 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 31 May 2022 14:33:10 -0400 Subject: [PATCH 28/47] Update cpp/src/arrow/compute/exec/asof_join_node.cc Co-authored-by: Weston Pace --- cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index ef0fc55d24c..e831a07c71f 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -53,7 +53,7 @@ class ConcurrentQueue { return item; } - void push(const T& item) { + void Push(const T& item) { std::unique_lock lock(mutex_); queue_.push(item); cond_.notify_one(); From 7aa252a121ed4a8c36c3543b0f1d4956b7277e9c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 31 May 2022 15:28:43 -0400 Subject: [PATCH 29/47] Update cpp/src/arrow/compute/exec/asof_join_node.cc Co-authored-by: Weston Pace --- cpp/src/arrow/compute/exec/asof_join_node.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index e831a07c71f..ea0b25eca57 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ 
b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -337,8 +337,8 @@ template class CompositeReferenceTable { public: explicit CompositeReferenceTable(size_t n_tables) : n_tables_(n_tables) { - assert(n_tables_ >= 1); - assert(n_tables_ <= MAX_TABLES); + DCHECK_GE(n_tables_, 1); + DCHECK_LE(n_tables_, MAX_TABLES); } size_t n_rows() const { return rows_.size(); } From 9c332ebaaabdee9068b64c0c41b93d4d3d6ca213 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 31 May 2022 15:34:12 -0400 Subject: [PATCH 30/47] Apply suggestions from code review Co-authored-by: Weston Pace --- cpp/src/arrow/compute/exec/asof_join_node.cc | 33 +++++++++----------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index ea0b25eca57..cbd90f7af3e 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -59,7 +59,7 @@ class ConcurrentQueue { cond_.notify_one(); } - util::optional try_pop() { + util::optional TryPop() { // Try to pop the oldest value from the queue (or return nullopt if none) std::unique_lock lock(mutex_); if (queue_.empty()) { @@ -71,7 +71,7 @@ class ConcurrentQueue { } } - bool empty() const { + bool Empty() const { std::unique_lock lock(mutex_); return queue_.empty(); } @@ -105,7 +105,7 @@ struct MemoStore { std::unordered_map _entries; - void store(const std::shared_ptr& batch, row_index_t row, int64_t time, + void Store(const std::shared_ptr& batch, row_index_t row, int64_t time, KeyType key) { auto& e = _entries[key]; // that we can do this assignment optionally, is why we @@ -116,13 +116,13 @@ struct MemoStore { e._time = time; } - util::optional get_entry_for_key(KeyType key) const { + util::optional GetEntryForKey(KeyType key) const { auto e = _entries.find(key); if (_entries.end() == e) return util::nullopt; return util::optional(&e->second); } - void remove_entries_with_lesser_time(int64_t ts) { + void 
RemoveEntriesWithLesserTime(int64_t ts) { size_t dbg_size0 = _entries.size(); for (auto e = _entries.begin(); e != _entries.end();) if (e->second._time < ts) @@ -166,7 +166,7 @@ class InputState { } bool is_time_or_key_column(col_index_t i) const { - assert(i < schema_->num_fields()); + DCHECK_LT(i, schema_->num_fields()); return (i == time_col_index_) || (i == key_col_index_); } @@ -203,9 +203,7 @@ class InputState { bool advance() { // Returns true if able to advance, false if not. - bool have_active_batch = - (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.empty(); - if (have_active_batch) { + if (!empty()) { // If we have an active batch if (++latest_ref_row_ >= (row_index_t)queue_.unsync_front()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. @@ -213,7 +211,7 @@ class InputState { latest_ref_row_ = 0; have_active_batch &= !queue_.try_pop(); if (have_active_batch) - assert(queue_.unsync_front()->num_rows() > 0); // empty batches disallowed + DCHECK_GT(queue_.unsync_front()->num_rows(), 0); // empty batches disallowed } } return have_active_batch; @@ -266,7 +264,7 @@ class InputState { util::optional get_memo_time_for_key(KeyType key) { auto r = get_memo_entry_for_key(key); - return r.has_value() ? util::make_optional((*r)->_time) : util::nullopt; + return r.has_value() ? 
(*r)->_time : util::nullopt; } void remove_memo_entries_with_lesser_time(int64_t ts) { @@ -347,14 +345,14 @@ class CompositeReferenceTable { // - LHS must have a valid key,timestep,and latest rows // - RHS must have valid data memo'ed for the key void emplace(std::vector>& in, int64_t tolerance) { - assert(in.size() == n_tables_); + DCHECK_EQ(in.size(), n_tables_); // Get the LHS key KeyType key = in[0]->get_latest_key(); // Add row and setup LHS // (the LHS state comes just from the latest row of the LHS table) - assert(!in[0]->empty()); + DCHECK(!in[0]->empty()); const std::shared_ptr& lhs_latest_batch = in[0]->get_latest_batch(); row_index_t lhs_latest_row = in[0]->get_latest_row(); @@ -378,7 +376,7 @@ class CompositeReferenceTable { util::optional opt_entry = in[i]->get_memo_entry_for_key(key); if (opt_entry.has_value()) { - assert(*opt_entry); + DCHECK(*opt_entry); if ((*opt_entry)->_time + tolerance >= lhs_latest_time) { // Have a valid entry const MemoStore::Entry* entry = *opt_entry; @@ -417,8 +415,8 @@ class CompositeReferenceTable { col_index_t i_dst_col = *i_dst_col_opt; const auto& src_field = state[i_table]->get_schema()->field(i_src_col); const auto& dst_field = output_schema->field(i_dst_col); - assert(src_field->type()->Equals(dst_field->type())); - assert(src_field->name() == dst_field->name()); + DCHECK(src_field->type()->Equals(dst_field->type())); + DCHECK_EQ(src_field->name(), dst_field->name()); const auto& field_type = src_field->type(); if (field_type->Equals(arrow::int32())) { @@ -445,7 +443,7 @@ class CompositeReferenceTable { } // Build the result - assert(sizeof(size_t) >= sizeof(int64_t)); // Make takes signed int64_t for num_rows + DCHECK_GE(sizeof(size_t), sizeof(int64_t)) << "AsofJoinNode requires size_t >= 8 bytes"; // TODO: check n_rows for cast std::shared_ptr r = @@ -718,7 +716,6 @@ class AsofJoinNode : public ExecNode { std::cerr << "StopProducing" << std::endl; // if(batch_count_.Cancel()) finished_.MarkFinished(); 
finished_.MarkFinished(); - for (auto&& input : inputs_) input->StopProducing(this); } arrow::Future<> finished() override { return finished_; } From cae55926ce3f43f028cf2a97375d13fe3d56d211 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 31 May 2022 15:16:53 -0400 Subject: [PATCH 31/47] Address PR comments --- cpp/src/arrow/compute/exec/asof_join_node.cc | 116 +++++++++---------- cpp/src/arrow/compute/exec/options.h | 27 +++-- 2 files changed, 76 insertions(+), 67 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index cbd90f7af3e..4b9f1c50101 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -21,7 +21,6 @@ #include #include #include -#include "arrow/compute/exec/asof_join.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" @@ -39,6 +38,14 @@ namespace arrow { namespace compute { +// Remove this when multiple keys and/or types is supported +typedef int32_t KeyType; + +// Maximum number of tables that can be joined +#define MAX_JOIN_TABLES 64 +typedef uint64_t row_index_t; +typedef int col_index_t; + /** * Simple implementation for an unbound concurrent queue */ @@ -143,10 +150,8 @@ class InputState { public: InputState(const std::shared_ptr& schema, - const std::string& time_col_name, const std::string& key_col_name, - util::optional wildcard_key) + const std::string& time_col_name, const std::string& key_col_name) : queue_(), - wildcard_key_(wildcard_key), schema_(schema), time_col_index_( schema->GetFieldIndex(time_col_name)), // TODO: handle missing field name @@ -180,7 +185,6 @@ class InputState { return queue_.empty(); } - int countbatches_processed_() const { return batches_processed_; } int count_total_batches() const { return total_batches_; } // Gets latest batch (precondition: must not be empty) @@ -249,7 +253,7 @@ class InputState { void push(const 
std::shared_ptr& rb) { if (rb->num_rows() > 0) { - queue_.push(rb); + queue_.Push(rb); } else { ++batches_processed_; // don't enqueue empty batches, just record as processed } @@ -258,7 +262,6 @@ class InputState { util::optional get_memo_entry_for_key(KeyType key) { auto r = memo_.get_entry_for_key(key); if (r.has_value()) return r; - if (wildcard_key_.has_value()) r = memo_.get_entry_for_key(*wildcard_key_); return r; } @@ -283,9 +286,6 @@ class InputState { // Pending record batches. The latest is the front. Batches cannot be empty. ConcurrentQueue> queue_; - // Wildcard key for this input, if applicable. - util::optional wildcard_key_; - // Schema associated with the input std::shared_ptr schema_; @@ -489,6 +489,12 @@ class CompositeReferenceTable { } }; +class AsofJoinSchema { + public: + std::shared_ptr MakeOutputSchema(const std::vector& inputs, + const AsofJoinNodeOptions& options); +}; + class AsofJoinNode : public ExecNode { // Constructs labels for inputs static std::vector build_input_labels( @@ -500,21 +506,21 @@ class AsofJoinNode : public ExecNode { // Advances the RHS as far as possible to be up to date for the current LHS timestamp bool update_rhs() { - auto& lhs = *_state.at(0); + auto& lhs = *state_.at(0); auto lhs_latest_time = lhs.get_latest_time(); bool any_updated = false; - for (size_t i = 1; i < _state.size(); ++i) - any_updated |= _state[i]->advance_and_memoize(lhs_latest_time); + for (size_t i = 1; i < state_.size(); ++i) + any_updated |= state_[i]->advance_and_memoize(lhs_latest_time); return any_updated; } // Returns false if RHS not up to date for LHS bool is_up_to_date_for_lhs_row() const { - auto& lhs = *_state[0]; + auto& lhs = *state_[0]; if (lhs.empty()) return false; // can't proceed if nothing on the LHS int64_t lhs_ts = lhs.get_latest_time(); - for (size_t i = 1; i < _state.size(); ++i) { - auto& rhs = *_state[i]; + for (size_t i = 1; i < state_.size(); ++i) { + auto& rhs = *state_[i]; if (!rhs.finished()) { // If RHS is 
finished, then we know it's up to date (but if it isn't, it might be // up to date) @@ -528,11 +534,12 @@ class AsofJoinNode : public ExecNode { } Result> process_inner() { - assert(!_state.empty()); - auto& lhs = *_state.at(0); + + assert(!state_.empty()); + auto& lhs = *state_.at(0); // Construct new target table if needed - CompositeReferenceTable dst(_state.size()); + CompositeReferenceTable dst(state_.size()); // Generate rows into the dst table until we either run out of data or hit the row // limit, or run out of input @@ -545,18 +552,18 @@ class AsofJoinNode : public ExecNode { // Only update if we have up-to-date information for the LHS row if (is_up_to_date_for_lhs_row()) { - dst.emplace(_state, _options.tolerance); + dst.emplace(state_, options_.tolerance); if (!lhs.advance()) break; // if we can't advance LHS, we're done for this batch } else { - if ((!any_advanced) && (_state.size() > 1)) break; // need to wait for new data + if ((!any_advanced) && (state_.size() > 1)) break; // need to wait for new data } } // Prune memo entries that have expired (to bound memory consumption) if (!lhs.empty()) { - for (size_t i = 1; i < _state.size(); ++i) { - _state[i]->remove_memo_entries_with_lesser_time(lhs.get_latest_time() - - _options.tolerance); + for (size_t i = 1; i < state_.size(); ++i) { + state_[i]->remove_memo_entries_with_lesser_time(lhs.get_latest_time() - + options_.tolerance); } } @@ -564,14 +571,14 @@ class AsofJoinNode : public ExecNode { if (dst.empty()) { return NULLPTR; } else { - return dst.materialize(output_schema(), _state); + return dst.materialize(output_schema(), state_); } } void process() { std::cerr << "process() begin\n"; - std::lock_guard guard(_gate); + std::lock_guard guard(gate_); if (finished_.is_finished()) { std::cerr << "InputReceived EARLYEND\n"; return; @@ -584,7 +591,7 @@ class AsofJoinNode : public ExecNode { if (result.ok()) { auto out_rb = *result; if (!out_rb) break; - ++_progress_batches_produced; + 
++batches_produced_; ExecBatch out_b(*out_rb); outputs_[0]->InputReceived(this, std::move(out_b)); } else { @@ -601,18 +608,16 @@ class AsofJoinNode : public ExecNode { // // It may happen here in cases where InputFinished was called before we were finished // producing results (so we didn't know the output size at that time) - if (_state.at(0)->finished()) { - total_batches_produced_ = util::make_optional(_progress_batches_produced); + if (state_.at(0)->finished()) { StopProducing(); - DCHECK(total_batches_produced_.has_value()); - outputs_[0]->InputFinished(this, *total_batches_produced_); + outputs_[0]->InputFinished(this, batches_produced_); } } void process_thread() { std::cerr << "AsofJoinNode::process_thread started.\n"; for (;;) { - if (!_process.pop()) { + if (!process_.Pop()) { std::cerr << "AsofJoinNode::process_thread done.\n"; return; } @@ -629,8 +634,8 @@ class AsofJoinNode : public ExecNode { std::unique_ptr schema_mgr); virtual ~AsofJoinNode() { - _process.push(false); // poison pill - _process_thread.join(); + process_.Push(false); // poison pill + process_thread_.join(); } static arrow::Result Make(ExecPlan* plan, std::vector inputs, @@ -664,8 +669,8 @@ class AsofJoinNode : public ExecNode { // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); - _state.at(k)->push(rb); - _process.push(true); + state_.at(k)->push(rb); + process_.Push(true); std::cerr << "InputReceived END\n"; } @@ -675,20 +680,18 @@ class AsofJoinNode : public ExecNode { } void InputFinished(ExecNode* input, int total_batches) override { std::cerr << "InputFinished BEGIN\n"; - // bool is_finished=false; { - std::lock_guard guard(_gate); + std::lock_guard guard(gate_); std::cerr << "InputFinished find\n"; ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); - // cerr << "set_total_batches for input " << k << ": " << total_batches << "\n"; - 
_state.at(k)->set_total_batches(total_batches); + state_.at(k)->set_total_batches(total_batches); } // Trigger a process call // The reason for this is that there are cases at the end of a table where we don't // know whether the RHS of the join is up-to-date until we know that the table is // finished. - _process.push(true); + process_.Push(true); std::cerr << "InputFinished END\n"; } @@ -724,22 +727,18 @@ class AsofJoinNode : public ExecNode { arrow::Future<> finished_; // InputStates // Each input state correponds to an input table - // - std::vector> _state; - std::mutex _gate; - AsofJoinNodeOptions _options; + std::vector> state_; + std::mutex gate_; + AsofJoinNodeOptions options_; // Queue for triggering processing of a given input // (a false value is a poison pill) - ConcurrentQueue _process; + ConcurrentQueue process_; // Worker thread - std::thread _process_thread; - - // Total batches produced, once we've finished -- only known at completion time. - util::optional total_batches_produced_; + std::thread process_thread_; // In-progress batches produced - int _progress_batches_produced = 0; + int batches_produced_ = 0; }; std::shared_ptr AsofJoinSchema::MakeOutputSchema( @@ -756,7 +755,7 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema( const auto& input_schema = inputs[j]->output_schema(); for (int i = 0; i < input_schema->num_fields(); ++i) { const auto& name = input_schema->field(i)->name(); - if ((name != *options.keys.name()) && (name != *options.time.name())) { + if ((name != *options.by_key.name()) && (name != *options.on_key.name())) { fields.push_back(input_schema->field(i)); } } @@ -773,15 +772,14 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), - _options(join_options), - _process(), - _process_thread(&AsofJoinNode::process_thread_wrapper, this) { + options_(join_options), + process_(), + 
process_thread_(&AsofJoinNode::process_thread_wrapper, this) { for (size_t i = 0; i < inputs.size(); ++i) - _state.push_back(::arrow::internal::make_unique( - inputs[i]->output_schema(), *_options.time.name(), *_options.keys.name(), - util::make_optional(0) /*TODO: make wildcard configuirable*/)); + state_.push_back(::arrow::internal::make_unique( + inputs[i]->output_schema(), *options_.on_key.name(), *options_.by_key.name())); col_index_t dst_offset = 0; - for (auto& state : _state) + for (auto& state : state_) dst_offset = state->init_src_to_dst_mapping(dst_offset, !!dst_offset); finished_ = arrow::Future<>::MakeFinished(); diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 8283fff353d..532bec58492 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -361,16 +361,27 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { Expression filter; }; + +/// \brief Make a node which implements asof join operation +/// +/// This node takes one left table and (n-1) right tables, and asof joins them +/// together. Batches produced by each inputs must be ordered by the "on" key. +/// The batch size that this node produces is decided by the left table. class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef time, FieldRef keys, int64_t tolerance) - : time(std::move(time)), keys(std::move(keys)), tolerance(tolerance) {} - - // time column - FieldRef time; - // keys used for the join. All tables must have the same join key. - FieldRef keys; - // tolerance for the inexact timestamp matching in nanoseconds + AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance) + : on_key(std::move(on_key)), by_key(std::move(by_key)), tolerance(tolerance) {} + + // "on" key for the join. Each input table must be sorted by the "on" key. Inexact + // match is used on the "on" key. 
i.e., a row is considered a match iff + // left_on - tolerance <= right_on <= left_on. + // Currently, "on" key must be an int64 field + FieldRef on_key; + // "by" key for the join. All tables must have the "by" key. Equality predicates + // are used on the "by" key. + // Currently, "by" key must be an int32 field + FieldRef by_key; + // tolerance for inexact "on" key matching int64_t tolerance; }; From 8f325004fcb8a46163ae08b9ed69aded29eff91c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 31 May 2022 16:23:44 -0400 Subject: [PATCH 32/47] Took another pass of remaining functions to mixed style --- cpp/src/arrow/compute/exec/asof_join_node.cc | 178 +++++++++---------- 1 file changed, 87 insertions(+), 91 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 4b9f1c50101..65fac041c3a 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -87,7 +87,7 @@ class ConcurrentQueue { // For this to be "safe": // 1) the caller logically guarantees that queue is not empty // 2) pop/try_pop cannot be called concurrently with this - const T& unsync_front() const { return queue_.front(); } + const T& UnsyncFront() const { return queue_.front(); } private: std::queue queue_; @@ -157,65 +157,68 @@ class InputState { schema->GetFieldIndex(time_col_name)), // TODO: handle missing field name key_col_index_(schema->GetFieldIndex(key_col_name)) {} - col_index_t init_src_to_dst_mapping(col_index_t dst_offset, + col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); for (int i = 0; i < schema_->num_fields(); ++i) - if (!(skip_time_and_key_fields && is_time_or_key_column(i))) + if (!(skip_time_and_key_fields && IsTimeOrKeyColumn(i))) src_to_dst_[i] = dst_offset++; return dst_offset; } - const util::optional& map_src_to_dst(col_index_t src) const { + const util::optional& 
MapSrcToDst(col_index_t src) const { return src_to_dst_[src]; } - bool is_time_or_key_column(col_index_t i) const { + bool IsTimeOrKeyColumn(col_index_t i) const { DCHECK_LT(i, schema_->num_fields()); return (i == time_col_index_) || (i == key_col_index_); } // Gets the latest row index, assuming the queue isn't empty - row_index_t get_latest_row() const { return latest_ref_row_; } + row_index_t GetLatestRow() const { return latest_ref_row_; } - bool empty() const { + bool Empty() const { + // cannot be empty if ref row is >0 -- can avoid slow queue lock + // below if (latest_ref_row_ > 0) - return false; // cannot be empty if ref row is >0 -- can avoid slow queue lock - // below - return queue_.empty(); + return false; + return queue_.Empty(); } - int count_total_batches() const { return total_batches_; } + int total_batches() const { return total_batches_; } // Gets latest batch (precondition: must not be empty) - const std::shared_ptr& get_latest_batch() const { - return queue_.unsync_front(); + const std::shared_ptr& GetLatestBatch() const { + return queue_.UnsyncFront(); } - KeyType get_latest_key() const { - return queue_.unsync_front() + KeyType GetLatestKey() const { + return queue_.UnsyncFront() ->column_data(key_col_index_) ->GetValues(1)[latest_ref_row_]; } - int64_t get_latest_time() const { - return queue_.unsync_front() + int64_t GetLatestTime() const { + return queue_.UnsyncFront() ->column_data(time_col_index_) ->GetValues(1)[latest_ref_row_]; } - bool finished() const { return batches_processed_ == total_batches_; } + bool Finished() const { return batches_processed_ == total_batches_; } - bool advance() { + bool Advance() { // Returns true if able to advance, false if not. 
+ bool have_active_batch = + (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); - if (!empty()) { + if (have_active_batch) { // If we have an active batch - if (++latest_ref_row_ >= (row_index_t)queue_.unsync_front()->num_rows()) { + if (++latest_ref_row_ >= (row_index_t) queue_.UnsyncFront()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. ++batches_processed_; latest_ref_row_ = 0; - have_active_batch &= !queue_.try_pop(); + have_active_batch &= !queue_.TryPop(); if (have_active_batch) - DCHECK_GT(queue_.unsync_front()->num_rows(), 0); // empty batches disallowed + DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed } } return have_active_batch; @@ -224,34 +227,34 @@ class InputState { // Advance the data to be immediately past the specified TS, updating latest and // latest_ref_row to the latest data prior to that immediate just past Returns true if // updates were made, false if not. - bool advance_and_memoize(int64_t ts) { + bool AdvanceAndMemoize(int64_t ts) { // Advance the right side row index until we reach the latest right row (for each key) // for the given left timestamp. // Check if already updated for TS (or if there is no latest) - if (empty()) return false; // can't advance if empty - auto latest_time = get_latest_time(); + if (Empty()) return false; // can't advance if empty + auto latest_time = GetLatestTime(); if (latest_time > ts) return false; // already advanced // Not updated. Try to update and possibly advance. bool updated = false; do { - latest_time = get_latest_time(); + latest_time = GetLatestTime(); // if advance() returns true, then the latest_ts must also be valid // Keep advancing right table until we hit the latest row that has // timestamp <= ts. This is because we only need the latest row for the // match given a left ts. 
if (latest_time <= ts) { - memo_.store(get_latest_batch(), latest_ref_row_, latest_time, get_latest_key()); + memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); } else { break; // hit a future timestamp -- done updating for now } updated = true; - } while (advance()); + } while (Advance()); return updated; } - void push(const std::shared_ptr& rb) { + void Push(const std::shared_ptr& rb) { if (rb->num_rows() > 0) { queue_.Push(rb); } else { @@ -259,19 +262,19 @@ class InputState { } } - util::optional get_memo_entry_for_key(KeyType key) { - auto r = memo_.get_entry_for_key(key); + util::optional GetMemoEntryForKey(KeyType key) { + auto r = memo_.GetEntryForKey(key); if (r.has_value()) return r; return r; } - util::optional get_memo_time_for_key(KeyType key) { - auto r = get_memo_entry_for_key(key); - return r.has_value() ? (*r)->_time : util::nullopt; + util::optional GetMemoTimeForKey(KeyType key) { + auto r = GetMemoEntryForKey(key); + return r.has_value() ? util::make_optional((*r)->_time) : util::nullopt; } - void remove_memo_entries_with_lesser_time(int64_t ts) { - memo_.remove_entries_with_lesser_time(ts); + void RemoveMemoEntriesWithLesserTime(int64_t ts) { + memo_.RemoveEntriesWithLesserTime(ts); } const std::shared_ptr& get_schema() const { return schema_; } @@ -344,19 +347,19 @@ class CompositeReferenceTable { // Adds the latest row from the input state as a new composite reference row // - LHS must have a valid key,timestep,and latest rows // - RHS must have valid data memo'ed for the key - void emplace(std::vector>& in, int64_t tolerance) { + void Emplace(std::vector>& in, int64_t tolerance) { DCHECK_EQ(in.size(), n_tables_); // Get the LHS key - KeyType key = in[0]->get_latest_key(); + KeyType key = in[0]->GetLatestKey(); // Add row and setup LHS // (the LHS state comes just from the latest row of the LHS table) - DCHECK(!in[0]->empty()); + DCHECK(!in[0]->Empty()); const std::shared_ptr& lhs_latest_batch = - 
in[0]->get_latest_batch(); - row_index_t lhs_latest_row = in[0]->get_latest_row(); - int64_t lhs_latest_time = in[0]->get_latest_time(); + in[0]->GetLatestBatch(); + row_index_t lhs_latest_row = in[0]->GetLatestRow(); + int64_t lhs_latest_time = in[0]->GetLatestTime(); if (0 == lhs_latest_row) { // On the first row of the batch, we resize the destination. // The destination size is dictated by the size of the LHS batch. @@ -368,13 +371,13 @@ class CompositeReferenceTable { auto& row = rows_.back(); row.refs[0].batch = lhs_latest_batch.get(); row.refs[0].row = lhs_latest_row; - add_record_batch_ref(lhs_latest_batch); + AddRecordBatchRef(lhs_latest_batch); // Get the state for that key from all on the RHS -- assumes it's up to date // (the RHS state comes from the memoized row references) for (size_t i = 1; i < in.size(); ++i) { util::optional opt_entry = - in[i]->get_memo_entry_for_key(key); + in[i]->GetMemoEntryForKey(key); if (opt_entry.has_value()) { DCHECK(*opt_entry); if ((*opt_entry)->_time + tolerance >= lhs_latest_time) { @@ -382,7 +385,7 @@ class CompositeReferenceTable { const MemoStore::Entry* entry = *opt_entry; row.refs[i].batch = entry->_batch.get(); row.refs[i].row = entry->_row; - add_record_batch_ref(entry->_batch); + AddRecordBatchRef(entry->_batch); continue; } } @@ -392,7 +395,7 @@ class CompositeReferenceTable { } // Materializes the current reference table into a target record batch - Result> materialize( + Result> Materialize( const std::shared_ptr& output_schema, const std::vector>& state) { // cerr << "materialize BEGIN\n"; @@ -410,7 +413,7 @@ class CompositeReferenceTable { { for (col_index_t i_src_col = 0; i_src_col < n_src_cols; ++i_src_col) { util::optional i_dst_col_opt = - state[i_table]->map_src_to_dst(i_src_col); + state[i_table]->MapSrcToDst(i_src_col); if (!i_dst_col_opt) continue; col_index_t i_dst_col = *i_dst_col_opt; const auto& src_field = state[i_table]->get_schema()->field(i_src_col); @@ -422,18 +425,18 @@ class 
CompositeReferenceTable { if (field_type->Equals(arrow::int32())) { ARROW_ASSIGN_OR_RAISE( arrays.at(i_dst_col), - (materialize_primitive_column(i_table, - i_src_col))); + (MaterializePrimitiveColumn(i_table, + i_src_col))); } else if (field_type->Equals(arrow::int64())) { ARROW_ASSIGN_OR_RAISE( arrays.at(i_dst_col), - (materialize_primitive_column(i_table, - i_src_col))); + (MaterializePrimitiveColumn(i_table, + i_src_col))); } else if (field_type->Equals(arrow::float64())) { ARROW_ASSIGN_OR_RAISE( arrays.at(i_dst_col), - (materialize_primitive_column(i_table, - i_src_col))); + (MaterializePrimitiveColumn(i_table, + i_src_col))); } else { ARROW_RETURN_NOT_OK( Status::Invalid("Unsupported data type: ", src_field->name())); @@ -465,13 +468,13 @@ class CompositeReferenceTable { size_t n_tables_; // Adds a RecordBatch ref to the mapping, if needed - void add_record_batch_ref(const std::shared_ptr& ref) { + void AddRecordBatchRef(const std::shared_ptr& ref) { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } template - Result> materialize_primitive_column(size_t i_table, - col_index_t i_col) { + Result> MaterializePrimitiveColumn(size_t i_table, + col_index_t i_col) { Builder builder; ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size())); for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { @@ -496,44 +499,37 @@ class AsofJoinSchema { }; class AsofJoinNode : public ExecNode { - // Constructs labels for inputs - static std::vector build_input_labels( - const std::vector& inputs) { - std::vector r(inputs.size()); - for (size_t i = 0; i < r.size(); ++i) r[i] = "input_" + std::to_string(i) + "_label"; - return r; - } // Advances the RHS as far as possible to be up to date for the current LHS timestamp - bool update_rhs() { + bool UpdateRhs() { auto& lhs = *state_.at(0); - auto lhs_latest_time = lhs.get_latest_time(); + auto lhs_latest_time = lhs.GetLatestTime(); bool any_updated = false; for (size_t i = 1; i < state_.size(); ++i) - 
any_updated |= state_[i]->advance_and_memoize(lhs_latest_time); + any_updated |= state_[i]->AdvanceAndMemoize(lhs_latest_time); return any_updated; } // Returns false if RHS not up to date for LHS - bool is_up_to_date_for_lhs_row() const { + bool IsUpToDateWithLhsRow() const { auto& lhs = *state_[0]; - if (lhs.empty()) return false; // can't proceed if nothing on the LHS - int64_t lhs_ts = lhs.get_latest_time(); + if (lhs.Empty()) return false; // can't proceed if nothing on the LHS + int64_t lhs_ts = lhs.GetLatestTime(); for (size_t i = 1; i < state_.size(); ++i) { auto& rhs = *state_[i]; - if (!rhs.finished()) { + if (!rhs.Finished()) { // If RHS is finished, then we know it's up to date (but if it isn't, it might be // up to date) - if (rhs.empty()) + if (rhs.Empty()) return false; // RHS isn't finished, but is empty --> not up to date - if (lhs_ts >= rhs.get_latest_time()) + if (lhs_ts >= rhs.GetLatestTime()) return false; // TS not up to date (and not finished) } } return true; } - Result> process_inner() { + Result> ProcessInner() { assert(!state_.empty()); auto& lhs = *state_.at(0); @@ -545,24 +541,24 @@ class AsofJoinNode : public ExecNode { // limit, or run out of input for (;;) { // If LHS is finished or empty then there's nothing we can do here - if (lhs.finished() || lhs.empty()) break; + if (lhs.Finished() || lhs.Empty()) break; // Advance each of the RHS as far as possible to be up to date for the LHS timestamp - bool any_advanced = update_rhs(); + bool any_advanced = UpdateRhs(); // Only update if we have up-to-date information for the LHS row - if (is_up_to_date_for_lhs_row()) { - dst.emplace(state_, options_.tolerance); - if (!lhs.advance()) break; // if we can't advance LHS, we're done for this batch + if (IsUpToDateWithLhsRow()) { + dst.Emplace(state_, options_.tolerance); + if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch } else { if ((!any_advanced) && (state_.size() > 1)) break; // need to wait for new data } } 
// Prune memo entries that have expired (to bound memory consumption) - if (!lhs.empty()) { + if (!lhs.Empty()) { for (size_t i = 1; i < state_.size(); ++i) { - state_[i]->remove_memo_entries_with_lesser_time(lhs.get_latest_time() - + state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - options_.tolerance); } } @@ -571,11 +567,11 @@ class AsofJoinNode : public ExecNode { if (dst.empty()) { return NULLPTR; } else { - return dst.materialize(output_schema(), state_); + return dst.Materialize(output_schema(), state_); } } - void process() { + void Process() { std::cerr << "process() begin\n"; std::lock_guard guard(gate_); @@ -586,7 +582,7 @@ class AsofJoinNode : public ExecNode { // Process batches while we have data for (;;) { - Result> result = process_inner(); + Result> result = ProcessInner(); if (result.ok()) { auto out_rb = *result; @@ -608,24 +604,24 @@ class AsofJoinNode : public ExecNode { // // It may happen here in cases where InputFinished was called before we were finished // producing results (so we didn't know the output size at that time) - if (state_.at(0)->finished()) { + if (state_.at(0)->Finished()) { StopProducing(); outputs_[0]->InputFinished(this, batches_produced_); } } - void process_thread() { + void ProcessThread() { std::cerr << "AsofJoinNode::process_thread started.\n"; for (;;) { if (!process_.Pop()) { std::cerr << "AsofJoinNode::process_thread done.\n"; return; } - process(); + Process(); } } - static void process_thread_wrapper(AsofJoinNode* node) { node->process_thread(); } + static void ProcessThreadWrapper(AsofJoinNode* node) { node->ProcessThread(); } public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, @@ -669,7 +665,7 @@ class AsofJoinNode : public ExecNode { // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); - state_.at(k)->push(rb); + state_.at(k)->Push(rb); process_.Push(true); std::cerr << "InputReceived END\n"; @@ -774,13 +770,13 @@ 
AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, /*num_outputs=*/1), options_(join_options), process_(), - process_thread_(&AsofJoinNode::process_thread_wrapper, this) { + process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { for (size_t i = 0; i < inputs.size(); ++i) state_.push_back(::arrow::internal::make_unique( inputs[i]->output_schema(), *options_.on_key.name(), *options_.by_key.name())); col_index_t dst_offset = 0; for (auto& state : state_) - dst_offset = state->init_src_to_dst_mapping(dst_offset, !!dst_offset); + dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); finished_ = arrow::Future<>::MakeFinished(); } From eead16e1dd37c5093e57cc24d8c98e819548118c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 31 May 2022 17:41:44 -0400 Subject: [PATCH 33/47] ninja lint --- cpp/src/arrow/compute/exec/asof_join_node.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 65fac041c3a..22877b305df 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -280,8 +280,8 @@ class InputState { const std::shared_ptr& get_schema() const { return schema_; } void set_total_batches(int n) { - assert(n >= 0); - assert(total_batches_ == -1); // shouldn't be set more than once + DCHECK_GE(n, 0); + DCHECK_EQ(total_batches_, -1) << "Set total batch more than once"; total_batches_ = n; } @@ -446,7 +446,7 @@ class CompositeReferenceTable { } // Build the result - DCHECK_GE(sizeof(size_t), sizeof(int64_t)) << "AsofJoinNode requires size_t >= 8 bytes"; + DCHECK_GE(sizeof(size_t), sizeof(int64_t)) << "Requires size_t >= 8 bytes"; // TODO: check n_rows for cast std::shared_ptr r = @@ -499,7 +499,6 @@ class AsofJoinSchema { }; class AsofJoinNode : public ExecNode { - // Advances the RHS as far as possible to be up to date for the current LHS timestamp bool UpdateRhs() { auto& lhs 
= *state_.at(0); @@ -530,8 +529,7 @@ class AsofJoinNode : public ExecNode { } Result> ProcessInner() { - - assert(!state_.empty()); + DCHECK(!state_.empty()); auto& lhs = *state_.at(0); // Construct new target table if needed From d0aec2faf05b068036d774ef9086f0724e5dd1e8 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 1 Jun 2022 10:39:37 -0400 Subject: [PATCH 34/47] Remove asof_join.h --- cpp/src/arrow/compute/exec/asof_join.h | 42 -------------------------- 1 file changed, 42 deletions(-) delete mode 100644 cpp/src/arrow/compute/exec/asof_join.h diff --git a/cpp/src/arrow/compute/exec/asof_join.h b/cpp/src/arrow/compute/exec/asof_join.h deleted file mode 100644 index 15f9e21c6b2..00000000000 --- a/cpp/src/arrow/compute/exec/asof_join.h +++ /dev/null @@ -1,42 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#pragma once - -#include -#include -#include "arrow/compute/exec/options.h" -#include "arrow/compute/exec/schema_util.h" - -namespace arrow { -namespace compute { - -typedef int32_t KeyType; - -// Maximum number of tables that can be joined -constexpr int kMaxJoinTables = 64 -typedef uint64_t row_index_t; -typedef int col_index_t; - -class AsofJoinSchema { - public: - std::shared_ptr MakeOutputSchema(const std::vector& inputs, - const AsofJoinNodeOptions& options); -}; - -} // namespace compute -} // namespace arrow From f783818aab9d690623880d500480dce373e400bb Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 1 Jun 2022 15:42:49 -0400 Subject: [PATCH 35/47] Address comments --- cpp/src/arrow/compute/exec/asof_join_node.cc | 45 ++++++++++---------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 22877b305df..86fb6aa0a38 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -192,11 +192,13 @@ class InputState { const std::shared_ptr& GetLatestBatch() const { return queue_.UnsyncFront(); } + KeyType GetLatestKey() const { return queue_.UnsyncFront() ->column_data(key_col_index_) ->GetValues(1)[latest_ref_row_]; } + int64_t GetLatestTime() const { return queue_.UnsyncFront() ->column_data(time_col_index_) @@ -206,13 +208,14 @@ class InputState { bool Finished() const { return batches_processed_ == total_batches_; } bool Advance() { + // Try advancing to the next row and update latest_ref_row_ // Returns true if able to advance, false if not. 
bool have_active_batch = (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); if (have_active_batch) { // If we have an active batch - if (++latest_ref_row_ >= (row_index_t) queue_.UnsyncFront()->num_rows()) { + if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. ++batches_processed_; latest_ref_row_ = 0; @@ -224,9 +227,10 @@ class InputState { return have_active_batch; } - // Advance the data to be immediately past the specified TS, updating latest and - // latest_ref_row to the latest data prior to that immediate just past Returns true if - // updates were made, false if not. + // Advance the data to be immediately past the specified timestamp, update + // latest_time and latest_ref_row to the value that immediately pass the + // specified timestamp. + // Returns true if updates were made, false if not. bool AdvanceAndMemoize(int64_t ts) { // Advance the right side row index until we reach the latest right row (for each key) // for the given left timestamp. @@ -240,7 +244,7 @@ class InputState { bool updated = false; do { latest_time = GetLatestTime(); - // if advance() returns true, then the latest_ts must also be valid + // if Advance() returns true, then the latest_ts must also be valid // Keep advancing right table until we hit the latest row that has // timestamp <= ts. This is because we only need the latest row for the // match given a left ts. 
@@ -447,8 +451,6 @@ class CompositeReferenceTable { // Build the result DCHECK_GE(sizeof(size_t), sizeof(int64_t)) << "Requires size_t >= 8 bytes"; - - // TODO: check n_rows for cast std::shared_ptr r = arrow::RecordBatch::Make(output_schema, (int64_t)n_rows, arrays); return r; @@ -517,12 +519,11 @@ class AsofJoinNode : public ExecNode { for (size_t i = 1; i < state_.size(); ++i) { auto& rhs = *state_[i]; if (!rhs.Finished()) { - // If RHS is finished, then we know it's up to date (but if it isn't, it might be - // up to date) + // If RHS is finished, then we know it's up to date if (rhs.Empty()) return false; // RHS isn't finished, but is empty --> not up to date if (lhs_ts >= rhs.GetLatestTime()) - return false; // TS not up to date (and not finished) + return false; // RHS isn't up to date (and not finished) } } return true; @@ -544,7 +545,11 @@ class AsofJoinNode : public ExecNode { // Advance each of the RHS as far as possible to be up to date for the LHS timestamp bool any_advanced = UpdateRhs(); - // Only update if we have up-to-date information for the LHS row + // If we have received enough inputs to produce the next output batch + // (decided by IsUpToDateWithLhsRow), we will perform the join and + // materialize the output batch. The join is done by advancing through + // the LHS and adding joined row to rows_ (done by Emplace). Finally, + // input batches that are no longer needed are removed to free up memory. 
if (IsUpToDateWithLhsRow()) { dst.Emplace(state_, options_.tolerance); if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch @@ -557,7 +562,7 @@ class AsofJoinNode : public ExecNode { if (!lhs.Empty()) { for (size_t i = 1; i < state_.size(); ++i) { state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - - options_.tolerance); + options_.tolerance); } } @@ -624,8 +629,7 @@ class AsofJoinNode : public ExecNode { public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema, - std::unique_ptr schema_mgr); + std::shared_ptr output_schema); virtual ~AsofJoinNode() { process_.Push(false); // poison pill @@ -634,13 +638,13 @@ class AsofJoinNode : public ExecNode { static arrow::Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { - std::unique_ptr schema_mgr = - ::arrow::internal::make_unique(); + AsofJoinSchema schema_mgr; const auto& join_options = checked_cast(options); std::shared_ptr output_schema = - schema_mgr->MakeOutputSchema(inputs, join_options); + schema_mgr.MakeOutputSchema(inputs, join_options); + DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; std::vector input_labels(inputs.size()); input_labels[0] = "left"; for (size_t i = 1; i < inputs.size(); ++i) { @@ -648,8 +652,7 @@ class AsofJoinNode : public ExecNode { } return plan->EmplaceNode(plan, inputs, std::move(input_labels), - join_options, std::move(output_schema), - std::move(schema_mgr)); + join_options, std::move(output_schema)); } const char* kind_name() const override { return "AsofJoinNode"; } @@ -717,7 +720,6 @@ class AsofJoinNode : public ExecNode { arrow::Future<> finished() override { return finished_; } private: - std::unique_ptr schema_mgr_; arrow::Future<> finished_; // InputStates // Each input state correponds to an input table @@ -761,8 +763,7 @@ std::shared_ptr AsofJoinSchema::MakeOutputSchema( 
AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema, - std::unique_ptr schema_mgr) + std::shared_ptr output_schema) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), From 1a79d107eb9bf67019da9897ceb2d314cbbc09c2 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 1 Jun 2022 16:08:41 -0400 Subject: [PATCH 36/47] Clean up tests --- cpp/src/arrow/compute/exec/asof_join_node.cc | 1 - cpp/src/arrow/compute/exec/asof_join_node_test.cc | 12 +++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 86fb6aa0a38..002dfcfe542 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -714,7 +714,6 @@ class AsofJoinNode : public ExecNode { } void StopProducing() override { std::cerr << "StopProducing" << std::endl; - // if(batch_count_.Cancel()) finished_.MarkFinished(); finished_.MarkFinished(); } arrow::Future<> finished() override { return finished_; } diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index e2d0b1c83ba..95f81810677 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -96,12 +96,9 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, AssertTablesEqual(*exp_table, *res_table, /*same_chunk_layout=*/false, /*flatten=*/true); - - std::cerr << "Result Equals" - << "\n"; } -void RunNonEmptyTest(bool exact_matches) { +void RunNonEmptyTest() { auto l_schema = schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); auto r0_schema = @@ -204,12 +201,9 @@ void RunNonEmptyTest(bool exact_matches) { CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 0); } 
-class AsofJoinTest : public testing::TestWithParam> {}; - -INSTANTIATE_TEST_SUITE_P(AsofJoinTest, AsofJoinTest, - ::testing::Combine(::testing::Values(false, true))); +class AsofJoinTest : public testing::Test {}; -TEST_P(AsofJoinTest, TestExactMatches) { RunNonEmptyTest(std::get<0>(GetParam())); } +TEST(AsofJoinTest, TestBasic) { RunNonEmptyTest(); } } // namespace compute } // namespace arrow From 83398c880a1a1e6bf8c54f21bcc23da1b56b8a9b Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 1 Jun 2022 16:10:12 -0400 Subject: [PATCH 37/47] ninja format --- cpp/src/arrow/compute/exec/asof_join_node.cc | 22 ++++++++------------ cpp/src/arrow/compute/exec/options.h | 1 - 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 002dfcfe542..14487845447 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -157,8 +157,7 @@ class InputState { schema->GetFieldIndex(time_col_name)), // TODO: handle missing field name key_col_index_(schema->GetFieldIndex(key_col_name)) {} - col_index_t InitSrcToDstMapping(col_index_t dst_offset, - bool skip_time_and_key_fields) { + col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); for (int i = 0; i < schema_->num_fields(); ++i) if (!(skip_time_and_key_fields && IsTimeOrKeyColumn(i))) @@ -181,8 +180,7 @@ class InputState { bool Empty() const { // cannot be empty if ref row is >0 -- can avoid slow queue lock // below - if (latest_ref_row_ > 0) - return false; + if (latest_ref_row_ > 0) return false; return queue_.Empty(); } @@ -360,8 +358,7 @@ class CompositeReferenceTable { // Add row and setup LHS // (the LHS state comes just from the latest row of the LHS table) DCHECK(!in[0]->Empty()); - const std::shared_ptr& lhs_latest_batch = - in[0]->GetLatestBatch(); + const std::shared_ptr& lhs_latest_batch = 
in[0]->GetLatestBatch(); row_index_t lhs_latest_row = in[0]->GetLatestRow(); int64_t lhs_latest_time = in[0]->GetLatestTime(); if (0 == lhs_latest_row) { @@ -380,8 +377,7 @@ class CompositeReferenceTable { // Get the state for that key from all on the RHS -- assumes it's up to date // (the RHS state comes from the memoized row references) for (size_t i = 1; i < in.size(); ++i) { - util::optional opt_entry = - in[i]->GetMemoEntryForKey(key); + util::optional opt_entry = in[i]->GetMemoEntryForKey(key); if (opt_entry.has_value()) { DCHECK(*opt_entry); if ((*opt_entry)->_time + tolerance >= lhs_latest_time) { @@ -430,17 +426,17 @@ class CompositeReferenceTable { ARROW_ASSIGN_OR_RAISE( arrays.at(i_dst_col), (MaterializePrimitiveColumn(i_table, - i_src_col))); + i_src_col))); } else if (field_type->Equals(arrow::int64())) { ARROW_ASSIGN_OR_RAISE( arrays.at(i_dst_col), (MaterializePrimitiveColumn(i_table, - i_src_col))); + i_src_col))); } else if (field_type->Equals(arrow::float64())) { ARROW_ASSIGN_OR_RAISE( arrays.at(i_dst_col), (MaterializePrimitiveColumn(i_table, - i_src_col))); + i_src_col))); } else { ARROW_RETURN_NOT_OK( Status::Invalid("Unsupported data type: ", src_field->name())); @@ -476,7 +472,7 @@ class CompositeReferenceTable { template Result> MaterializePrimitiveColumn(size_t i_table, - col_index_t i_col) { + col_index_t i_col) { Builder builder; ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size())); for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { @@ -562,7 +558,7 @@ class AsofJoinNode : public ExecNode { if (!lhs.Empty()) { for (size_t i = 1; i < state_.size(); ++i) { state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - - options_.tolerance); + options_.tolerance); } } diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 532bec58492..2f5a9cf7269 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -361,7 +361,6 @@ class ARROW_EXPORT 
HashJoinNodeOptions : public ExecNodeOptions { Expression filter; }; - /// \brief Make a node which implements asof join operation /// /// This node takes one left table and (n-1) right tables, and asof joins them From 9f3d5c906fc5f7ce4c1d47d3c0e8f2f4098288b1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 1 Jun 2022 16:40:30 -0400 Subject: [PATCH 38/47] Use implicit ctor for optional --- cpp/src/arrow/compute/exec/asof_join_node.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 14487845447..df0c5adb71d 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -272,7 +272,12 @@ class InputState { util::optional GetMemoTimeForKey(KeyType key) { auto r = GetMemoEntryForKey(key); - return r.has_value() ? util::make_optional((*r)->_time) : util::nullopt; + // return r.has_value() ? util::make_optional((*r)->_time) : util::nullopt; + if (r.has_value()) { + return (*r)->_time; + } else { + return util::nullopt; + } } void RemoveMemoEntriesWithLesserTime(int64_t ts) { @@ -705,8 +710,6 @@ class AsofJoinNode : public ExecNode { void StopProducing(ExecNode* output) override { DCHECK_EQ(output, outputs_[0]); StopProducing(); - std::cout << "StopProducing" - << "\n"; } void StopProducing() override { std::cerr << "StopProducing" << std::endl; From e61b9c1b9bd241e0982aefb5e51e5d1c038182b4 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 1 Jun 2022 17:52:29 -0400 Subject: [PATCH 39/47] Refactor tests --- cpp/src/arrow/compute/exec/asof_join_node.cc | 5 +- .../arrow/compute/exec/asof_join_node_test.cc | 226 ++++++++++++------ 2 files changed, 152 insertions(+), 79 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index df0c5adb71d..1a298190e47 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ 
b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -265,14 +265,11 @@ class InputState { } util::optional GetMemoEntryForKey(KeyType key) { - auto r = memo_.GetEntryForKey(key); - if (r.has_value()) return r; - return r; + return memo_.GetEntryForKey(key); } util::optional GetMemoTimeForKey(KeyType key) { auto r = GetMemoEntryForKey(key); - // return r.has_value() ? util::make_optional((*r)->_time) : util::nullopt; if (r.has_value()) { return (*r)->_time; } else { diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 95f81810677..086d33f2ab4 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -95,10 +95,13 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, ASSERT_OK_AND_ASSIGN(auto res_table, TableFromExecBatches(exp_batches.schema, res)); AssertTablesEqual(*exp_table, *res_table, - /*same_chunk_layout=*/false, /*flatten=*/true); + /*same_chunk_layout=*/true, /*flatten=*/true); } -void RunNonEmptyTest() { +void DoRunBasicTest(const std::vector& l_data, + const std::vector& r0_data, + const std::vector& r1_data, + const std::vector& exp_data, int64_t tolerance) { auto l_schema = schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); auto r0_schema = @@ -116,94 +119,167 @@ void RunNonEmptyTest() { // Test three table join BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; + l_batches = GenerateBatchesFromString(l_schema, l_data); + r0_batches = GenerateBatchesFromString(r0_schema, r0_data); + r1_batches = GenerateBatchesFromString(r1_schema, r1_data); + exp_batches = GenerateBatchesFromString(exp_schema, exp_data); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", + tolerance); +} + +class AsofJoinTest : public testing::Test {}; +TEST(AsofJoinTest, TestBasic1) { // Single key, single batch - l_batches = GenerateBatchesFromString(l_schema, {R"([[0, 1, 
1.0], [1000, 1, 2.0]])"}); - r0_batches = GenerateBatchesFromString(r0_schema, {R"([[0, 1, 11.0]])"}); - r1_batches = GenerateBatchesFromString(r1_schema, {R"([[1000, 1, 101.0]])"}); - exp_batches = GenerateBatchesFromString( - exp_schema, {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"}); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); + DoRunBasicTest( + /*l*/ {R"([[0, 1, 1.0], [1000, 1, 2.0]])"}, + /*r0*/ {R"([[0, 1, 11.0]])"}, + /*r1*/ {R"([[1000, 1, 101.0]])"}, + /*exp*/ {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"}, 1000); +} +TEST(AsofJoinTest, TestBasic2) { // Single key, multiple batches - l_batches = - GenerateBatchesFromString(l_schema, {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}); - r0_batches = - GenerateBatchesFromString(r0_schema, {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"}); - r1_batches = GenerateBatchesFromString(r1_schema, - {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"}); - exp_batches = GenerateBatchesFromString( - exp_schema, {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); + DoRunBasicTest( + /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, + /*r0*/ {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"}, + /*r1*/ {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"}, + /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); +} +TEST(AsofJoinTest, TestBasic3) { // Single key, multiple left batches, single right batches + DoRunBasicTest( + /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, + /*r0*/ {R"([[0, 1, 11.0], [1000, 1, 12.0]])"}, + /*r1*/ {R"([[0, 1, 101.0], [1000, 1, 102.0]])"}, + /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); +} - l_batches = - GenerateBatchesFromString(l_schema, {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}); - - r0_batches = - GenerateBatchesFromString(r0_schema, {R"([[0, 1, 11.0], [1000, 1, 12.0]])"}); - 
r1_batches = - GenerateBatchesFromString(r1_schema, {R"([[0, 1, 101.0], [1000, 1, 102.0]])"}); - exp_batches = GenerateBatchesFromString( - exp_schema, {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); - +TEST(AsofJoinTest, TestBasic4) { // Multi key, multiple batches, misaligned batches - l_batches = GenerateBatchesFromString( - l_schema, + DoRunBasicTest( + /*l*/ {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}); - r0_batches = GenerateBatchesFromString( - r0_schema, {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}); - r1_batches = GenerateBatchesFromString( - r0_schema, {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); - exp_batches = GenerateBatchesFromString( - exp_schema, + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 1000); + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + 1000); +} +TEST(AsofJoinTest, TestBasic5) { // Multi key, multiple batches, misaligned batches, smaller tolerance - l_batches = GenerateBatchesFromString( - l_schema, - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 
4.0], [2000, 2, 24.0]])"}); - r0_batches = GenerateBatchesFromString( - r0_schema, {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}); - r1_batches = GenerateBatchesFromString( - r0_schema, {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); - exp_batches = GenerateBatchesFromString( - exp_schema, - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 500); - - // // Multi key, multiple batches, misaligned batches, zero tolerance - l_batches = GenerateBatchesFromString( - l_schema, - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}); - r0_batches = GenerateBatchesFromString( - r0_schema, {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}); - r1_batches = GenerateBatchesFromString( - r0_schema, {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}); - exp_batches = GenerateBatchesFromString( - exp_schema, - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null, null], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null, null]])"}); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", 0); + DoRunBasicTest(/*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 
2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + 500); } -class AsofJoinTest : public testing::Test {}; +TEST(AsofJoinTest, TestBasic6) { + // Multi key, multiple batches, misaligned batches, zero tolerance + DoRunBasicTest(/*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null, null], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null, null]])"}, + 0); +} + +TEST(AsofJoinTest, TestEmpty1) { + // Empty left batch + DoRunBasicTest(/*l*/ + {R"([])", R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + 1000); +} + +TEST(AsofJoinTest, TestEmpty2) { + // Empty left input + DoRunBasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", + R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 
101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([])"}, 1000); +} -TEST(AsofJoinTest, TestBasic) { RunNonEmptyTest(); } +TEST(AsofJoinTest, TestEmpty3) { + // Empty right batch + DoRunBasicTest(/*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([])", R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", + R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + 1000); +} + +TEST(AsofJoinTest, TestEmpty4) { + // Empty right input + DoRunBasicTest(/*l*/ + {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", + R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([[0, 2, 1001.0], [500, 1, 101.0]])", + R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + /*exp*/ + {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, null, 1002.0]])", + R"([[2000, 1, 4.0, null, 103.0], [2000, 2, 24.0, null, 1002.0]])"}, + 1000); +} + +TEST(AsofJoinTest, TestEmpty5) { + // All empty + DoRunBasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); +} } // namespace compute } // namespace arrow From 1b4f26b2382ac68fadf579cbd62be4b5a15fbd01 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 2 Jun 2022 14:27:24 -0400 Subject: [PATCH 40/47] Address comments and add check/test for unsuported datatypes --- cpp/src/arrow/compute/exec/asof_join_node.cc | 93 ++++++++++++------- 
.../arrow/compute/exec/asof_join_node_test.cc | 42 ++++++++- 2 files changed, 97 insertions(+), 38 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 1a298190e47..a5184c54d35 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -16,6 +16,7 @@ // under the License. #include +#include #include #include @@ -402,7 +403,6 @@ class CompositeReferenceTable { const std::vector>& state) { // cerr << "materialize BEGIN\n"; DCHECK_EQ(state.size(), n_tables_); - DCHECK_GE(state.size(), 1); // Don't build empty batches size_t n_rows = rows_.size(); @@ -434,6 +434,10 @@ class CompositeReferenceTable { arrays.at(i_dst_col), (MaterializePrimitiveColumn(i_table, i_src_col))); + } else if (field_type->Equals(arrow::float32())) { + ARROW_ASSIGN_OR_RAISE(arrays.at(i_dst_col), + (MaterializePrimitiveColumn( + i_table, i_src_col))); } else if (field_type->Equals(arrow::float64())) { ARROW_ASSIGN_OR_RAISE( arrays.at(i_dst_col), @@ -492,12 +496,6 @@ class CompositeReferenceTable { } }; -class AsofJoinSchema { - public: - std::shared_ptr MakeOutputSchema(const std::vector& inputs, - const AsofJoinNodeOptions& options); -}; - class AsofJoinNode : public ExecNode { // Advances the RHS as far as possible to be up to date for the current LHS timestamp bool UpdateRhs() { @@ -541,7 +539,7 @@ class AsofJoinNode : public ExecNode { if (lhs.Finished() || lhs.Empty()) break; // Advance each of the RHS as far as possible to be up to date for the LHS timestamp - bool any_advanced = UpdateRhs(); + bool any_rhs_advanced = UpdateRhs(); // If we have received enough inputs to produce the next output batch // (decided by IsUpToDateWithLhsRow), we will perform the join and @@ -552,7 +550,7 @@ class AsofJoinNode : public ExecNode { dst.Emplace(state_, options_.tolerance); if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch } else { - if 
((!any_advanced) && (state_.size() > 1)) break; // need to wait for new data + if (!any_rhs_advanced) break; // need to wait for new data } } @@ -634,15 +632,51 @@ class AsofJoinNode : public ExecNode { process_thread_.join(); } + static arrow::Result> MakeOutputSchema( + const std::vector& inputs, const AsofJoinNodeOptions& options) { + std::vector> fields; + + // Take all non-key, non-time RHS fields + for (size_t j = 0; j < inputs.size(); ++j) { + const auto& input_schema = inputs[j]->output_schema(); + for (int i = 0; i < input_schema->num_fields(); ++i) { + const auto field = input_schema->field(i); + if (field->name() == *options.on_key.name()) { + if (supported_on_types_.find(field->type()) == supported_on_types_.end()) { + return Status::Invalid("Unsupported type for on key: ", field->type()); + } + // Only add on field from the left table + if (j == 0) { + fields.push_back(field); + } + } else if (field->name() == *options.by_key.name()) { + if (supported_by_types_.find(field->type()) == supported_by_types_.end()) { + return Status::Invalid("Unsupported type for by key: ", field->type()); + } + // Only add by field from the left table + if (j == 0) { + fields.push_back(field); + } + } else { + if (supported_data_types_.find(field->type()) == supported_data_types_.end()) { + return Status::Invalid("Unsupported data type:", field->type()); + } + + fields.push_back(field); + } + } + } + return std::make_shared(fields); + } + static arrow::Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { - AsofJoinSchema schema_mgr; + DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; const auto& join_options = checked_cast(options); - std::shared_ptr output_schema = - schema_mgr.MakeOutputSchema(inputs, join_options); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, + MakeOutputSchema(inputs, join_options)); - DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; std::vector input_labels(inputs.size()); 
input_labels[0] = "left"; for (size_t i = 1; i < inputs.size(); ++i) { @@ -715,6 +749,10 @@ class AsofJoinNode : public ExecNode { arrow::Future<> finished() override { return finished_; } private: + static const std::set> supported_on_types_; + static const std::set> supported_by_types_; + static const std::set> supported_data_types_; + arrow::Future<> finished_; // InputStates // Each input state correponds to an input table @@ -732,29 +770,6 @@ class AsofJoinNode : public ExecNode { int batches_produced_ = 0; }; -std::shared_ptr AsofJoinSchema::MakeOutputSchema( - const std::vector& inputs, const AsofJoinNodeOptions& options) { - std::vector> fields; - DCHECK_GT(inputs.size(), 1); - - // Directly map LHS fields - for (int i = 0; i < inputs[0]->output_schema()->num_fields(); ++i) - fields.push_back(inputs[0]->output_schema()->field(i)); - - // Take all non-key, non-time RHS fields - for (size_t j = 1; j < inputs.size(); ++j) { - const auto& input_schema = inputs[j]->output_schema(); - for (int i = 0; i < input_schema->num_fields(); ++i) { - const auto& name = input_schema->field(i)->name(); - if ((name != *options.by_key.name()) && (name != *options.on_key.name())) { - fields.push_back(input_schema->field(i)); - } - } - } - - return std::make_shared(fields); -} - AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const AsofJoinNodeOptions& join_options, @@ -775,6 +790,12 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, finished_ = arrow::Future<>::MakeFinished(); } +// Currently supported types +const std::set> AsofJoinNode::supported_on_types_ = {int64()}; +const std::set> AsofJoinNode::supported_by_types_ = {int32()}; +const std::set> AsofJoinNode::supported_data_types_ = { + int32(), int64(), float32(), float64()}; + namespace internal { void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make)); diff --git 
a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 086d33f2ab4..5d1a6ce110c 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -107,14 +107,14 @@ void DoRunBasicTest(const std::vector& l_data, auto r0_schema = schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())}); auto r1_schema = - schema({field("time", int64()), field("key", int32()), field("r1_v0", float64())}); + schema({field("time", int64()), field("key", int32()), field("r1_v0", float32())}); auto exp_schema = schema({ field("time", int64()), field("key", int32()), field("l_v0", float64()), field("r0_v0", float64()), - field("r1_v0", float64()), + field("r1_v0", float32()), }); // Test three table join @@ -127,6 +127,25 @@ void DoRunBasicTest(const std::vector& l_data, tolerance); } +void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + BatchesWithSchema l_batches = GenerateBatchesFromString(l_schema, {R"([])"}); + BatchesWithSchema r_batches = GenerateBatchesFromString(r_schema, {R"([])"}); + + auto exec_ctx = + arrow::internal::make_unique(default_memory_pool(), nullptr); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + + AsofJoinNodeOptions join_options("time", "key", 0); + Declaration join{"asofjoin", join_options}; + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); + + ASSERT_RAISES(Invalid, join.AddToPlan(plan.get())); +} + class AsofJoinTest : public testing::Test {}; TEST(AsofJoinTest, TestBasic1) { @@ -281,5 +300,24 @@ TEST(AsofJoinTest, TestEmpty5) { {R"([])"}, 1000); } +TEST(AsofJoinTest, TestUnsupportedOntype) { + DoRunInvalidTypeTest( + schema({field("time", utf8()), field("key", int32()), 
field("l_v0", float64())}), + schema({field("time", utf8()), field("key", int32()), field("r0_v0", float32())})); +} + +TEST(AsofJoinTest, TestUnsupportedBytype) { + DoRunInvalidTypeTest( + schema({field("time", int64()), field("key", utf8()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", utf8()), field("r0_v0", float32())})); +} + +TEST(AsofJoinTest, TestUnsupportedDatatype) { + // Utf8 is unsupported + DoRunInvalidTypeTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", utf8())})); +} + } // namespace compute } // namespace arrow From 121b3bdf35fe7b1b644136b827951c015cef45d2 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 14 Jun 2022 11:17:18 -0400 Subject: [PATCH 41/47] Apply suggestions from code review Co-authored-by: Weston Pace --- cpp/src/arrow/compute/exec/options.h | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 2f5a9cf7269..add612f5d47 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -363,9 +363,11 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// \brief Make a node which implements asof join operation /// -/// This node takes one left table and (n-1) right tables, and asof joins them -/// together. Batches produced by each inputs must be ordered by the "on" key. -/// The batch size that this node produces is decided by the left table. +/// Note, this API is experimental and will change in the future +/// +/// This node takes one left table and any number of right tables, and asof joins them +/// together. Batches produced by each input must be ordered by the "on" key. +/// This node will output one row for each row in the left table. 
class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance) @@ -376,9 +378,9 @@ class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { // left_on - tolerance <= right_on <= left_on. // Currently, "on" key must be an int64 field FieldRef on_key; - // "by" key for the join. All tables must have the "by" key. Equity prediciate - // are used on the "by" key. - // Currently, "by" key must be an int32 field + // "by" key for the join. All tables must have the "by" key. Exact equality + // is used for the "by" key. + // Currently, the "by" key must be an int32 field FieldRef by_key; // tolerance for inexact "on" key matching int64_t tolerance; From 66b6b98c38ab7f2f1e372de401839b6fe32a9d5a Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 14 Jun 2022 11:35:18 -0400 Subject: [PATCH 42/47] Apply suggestions from code review Co-authored-by: Weston Pace Co-authored-by: Ivan Chau --- cpp/src/arrow/compute/exec/asof_join_node.cc | 14 +++++++------- cpp/src/arrow/compute/exec/asof_join_node_test.cc | 5 ++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index a5184c54d35..f22f4b19a9e 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -227,7 +227,7 @@ class InputState { } // Advance the data to be immediately past the specified timestamp, update - // latest_time and latest_ref_row to the value that immediately pass the + // latest_time and latest_ref_row to the value that immediately follows the // specified timestamp. // Returns true if updates were made, false if not. 
bool AdvanceAndMemoize(int64_t ts) { @@ -643,7 +643,7 @@ class AsofJoinNode : public ExecNode { const auto field = input_schema->field(i); if (field->name() == *options.on_key.name()) { if (supported_on_types_.find(field->type()) == supported_on_types_.end()) { - return Status::Invalid("Unsupported type for on key: ", field->type()); + return Status::Invalid("Unsupported type for on key: ", field->name()); } // Only add on field from the left table if (j == 0) { @@ -651,7 +651,7 @@ class AsofJoinNode : public ExecNode { } } else if (field->name() == *options.by_key.name()) { if (supported_by_types_.find(field->type()) == supported_by_types_.end()) { - return Status::Invalid("Unsupported type for by key: ", field->type()); + return Status::Invalid("Unsupported type for by key: ", field->name()); } // Only add by field from the left table if (j == 0) { @@ -659,7 +659,7 @@ class AsofJoinNode : public ExecNode { } } else { if (supported_data_types_.find(field->type()) == supported_data_types_.end()) { - return Status::Invalid("Unsupported data type:", field->type()); + return Status::Invalid("Unsupported data type:", field->name()); } fields.push_back(field); @@ -749,9 +749,9 @@ class AsofJoinNode : public ExecNode { arrow::Future<> finished() override { return finished_; } private: - static const std::set> supported_on_types_; - static const std::set> supported_by_types_; - static const std::set> supported_data_types_; + static const std::set> kSupportedOnTypes; + static const std::set> kSupportedByTypes; + static const std::set> kSupportedDataTypes; arrow::Future<> finished_; // InputStates diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 5d1a6ce110c..cd2262a4a74 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -132,9 +132,8 @@ void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, BatchesWithSchema l_batches = 
GenerateBatchesFromString(l_schema, {R"([])"}); BatchesWithSchema r_batches = GenerateBatchesFromString(r_schema, {R"([])"}); - auto exec_ctx = - arrow::internal::make_unique(default_memory_pool(), nullptr); - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + ExecContext exec_ctx; + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); AsofJoinNodeOptions join_options("time", "key", 0); Declaration join{"asofjoin", join_options}; From 478f3b7dd17528dc32f156c84483e0464ff91e5a Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 21 Jun 2022 14:51:23 -0400 Subject: [PATCH 43/47] Address comments --- cpp/src/arrow/compute/exec/asof_join_node.cc | 111 +++++++++--------- .../arrow/compute/exec/asof_join_node_test.cc | 48 +++----- cpp/src/arrow/compute/exec/options.h | 20 ++-- cpp/src/arrow/compute/exec/test_util.cc | 24 ++++ cpp/src/arrow/compute/exec/test_util.h | 5 + 5 files changed, 114 insertions(+), 94 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index f22f4b19a9e..d7e20f5a5a4 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -15,26 +15,26 @@ // specific language governing permissions and limitations // under the License. 
+#include #include +#include #include +#include +#include #include -#include -#include -#include +#include "arrow/array/builder_primitive.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/util.h" +#include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" #include "arrow/util/make_unique.h" - -#include -#include -#include +#include "arrow/util/optional.h" namespace arrow { namespace compute { @@ -101,46 +101,41 @@ struct MemoStore { struct Entry { // Timestamp associated with the entry - int64_t _time; + int64_t time; // Batch associated with the entry (perf is probably OK for this; batches change // rarely) - std::shared_ptr _batch; + std::shared_ptr batch; // Row associated with the entry - row_index_t _row; + row_index_t row; }; - std::unordered_map _entries; + std::unordered_map entries_; void Store(const std::shared_ptr& batch, row_index_t row, int64_t time, KeyType key) { - auto& e = _entries[key]; + auto& e = entries_[key]; // that we can do this assignment optionally, is why we // can get array with using shared_ptr above (the batch // shouldn't change that often) - if (e._batch != batch) e._batch = batch; - e._row = row; - e._time = time; + if (e.batch != batch) e.batch = batch; + e.row = row; + e.time = time; } util::optional GetEntryForKey(KeyType key) const { - auto e = _entries.find(key); - if (_entries.end() == e) return util::nullopt; + auto e = entries_.find(key); + if (entries_.end() == e) return util::nullopt; return util::optional(&e->second); } void RemoveEntriesWithLesserTime(int64_t ts) { - size_t dbg_size0 = _entries.size(); - for (auto e = _entries.begin(); e != _entries.end();) - if (e->second._time < ts) - e = _entries.erase(e); + for (auto e = entries_.begin(); e != entries_.end();) + if (e->second.time < ts) + e = entries_.erase(e); else 
++e; - size_t dbg_size1 = _entries.size(); - if (dbg_size1 < dbg_size0) { - // cerr << "Removed " << dbg_size0-dbg_size1 << " memo entries.\n"; - } } }; @@ -154,8 +149,7 @@ class InputState { const std::string& time_col_name, const std::string& key_col_name) : queue_(), schema_(schema), - time_col_index_( - schema->GetFieldIndex(time_col_name)), // TODO: handle missing field name + time_col_index_(schema->GetFieldIndex(time_col_name)), key_col_index_(schema->GetFieldIndex(key_col_name)) {} col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { @@ -227,7 +221,7 @@ class InputState { } // Advance the data to be immediately past the specified timestamp, update - // latest_time and latest_ref_row to the value that immediately follows the + // latest_time and latest_ref_row to the value that immediately pass the // specified timestamp. // Returns true if updates were made, false if not. bool AdvanceAndMemoize(int64_t ts) { @@ -272,7 +266,7 @@ class InputState { util::optional GetMemoTimeForKey(KeyType key) { auto r = GetMemoEntryForKey(key); if (r.has_value()) { - return (*r)->_time; + return (*r)->time; } else { return util::nullopt; } @@ -291,31 +285,23 @@ class InputState { } private: - // Pending record batches. The latest is the front. Batches cannot be empty. + // Pending record batches. The latest is the front. Batches cannot be empty. 
ConcurrentQueue> queue_; - // Schema associated with the input std::shared_ptr schema_; - // Total number of batches (only int because InputFinished uses int) int total_batches_ = -1; - // Number of batches processed so far (only int because InputFinished uses int) int batches_processed_ = 0; - // Index of the time col col_index_t time_col_index_; - // Index of the key col col_index_t key_col_index_; - // Index of the latest row reference within; if >0 then queue_ cannot be empty - row_index_t latest_ref_row_ = - 0; // must be < queue_.front()->num_rows() if queue_ is non-empty - + // Must be < queue_.front()->num_rows() if queue_ is non-empty + row_index_t latest_ref_row_ = 0; // Stores latest known values for the various keys MemoStore memo_; - // Mapping of source columns to destination columns std::vector> src_to_dst_; }; @@ -383,12 +369,12 @@ class CompositeReferenceTable { util::optional opt_entry = in[i]->GetMemoEntryForKey(key); if (opt_entry.has_value()) { DCHECK(*opt_entry); - if ((*opt_entry)->_time + tolerance >= lhs_latest_time) { + if ((*opt_entry)->time + tolerance >= lhs_latest_time) { // Have a valid entry const MemoStore::Entry* entry = *opt_entry; - row.refs[i].batch = entry->_batch.get(); - row.refs[i].row = entry->_row; - AddRecordBatchRef(entry->_batch); + row.refs[i].batch = entry->batch.get(); + row.refs[i].row = entry->row; + AddRecordBatchRef(entry->batch); continue; } } @@ -636,21 +622,31 @@ class AsofJoinNode : public ExecNode { const std::vector& inputs, const AsofJoinNodeOptions& options) { std::vector> fields; + const auto& on_field_name = *options.on_key.name(); + const auto& by_field_name = *options.by_key.name(); + // Take all non-key, non-time RHS fields for (size_t j = 0; j < inputs.size(); ++j) { const auto& input_schema = inputs[j]->output_schema(); + const auto& on_field_ix = input_schema->GetFieldIndex(on_field_name); + const auto& by_field_ix = input_schema->GetFieldIndex(by_field_name); + + if ((on_field_ix == -1) | 
(by_field_ix == -1)) { + return Status::Invalid("Missing join key on table ", j); + } + for (int i = 0; i < input_schema->num_fields(); ++i) { const auto field = input_schema->field(i); - if (field->name() == *options.on_key.name()) { - if (supported_on_types_.find(field->type()) == supported_on_types_.end()) { + if (field->name() == on_field_name) { + if (kSupportedOnTypes_.find(field->type()) == kSupportedOnTypes_.end()) { return Status::Invalid("Unsupported type for on key: ", field->name()); } // Only add on field from the left table if (j == 0) { fields.push_back(field); } - } else if (field->name() == *options.by_key.name()) { - if (supported_by_types_.find(field->type()) == supported_by_types_.end()) { + } else if (field->name() == by_field_name) { + if (kSupportedByTypes_.find(field->type()) == kSupportedByTypes_.end()) { return Status::Invalid("Unsupported type for by key: ", field->name()); } // Only add by field from the left table @@ -658,8 +654,8 @@ class AsofJoinNode : public ExecNode { fields.push_back(field); } } else { - if (supported_data_types_.find(field->type()) == supported_data_types_.end()) { - return Status::Invalid("Unsupported data type:", field->name()); + if (kSupportedDataTypes_.find(field->type()) == kSupportedDataTypes_.end()) { + return Status::Invalid("Unsupported data type: ", field->name()); } fields.push_back(field); @@ -697,6 +693,9 @@ class AsofJoinNode : public ExecNode { // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); + std::stringstream stream; + stream << "(k=" << k << ") time=(" << *rb->columns()[0] << ")\n"; + std::cerr << stream.str(); state_.at(k)->Push(rb); process_.Push(true); @@ -749,9 +748,9 @@ class AsofJoinNode : public ExecNode { arrow::Future<> finished() override { return finished_; } private: - static const std::set> kSupportedOnTypes; - static const std::set> kSupportedByTypes; - static const std::set> kSupportedDataTypes; + static const std::set> kSupportedOnTypes_; + static 
const std::set> kSupportedByTypes_; + static const std::set> kSupportedDataTypes_; arrow::Future<> finished_; // InputStates @@ -791,9 +790,9 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, } // Currently supported types -const std::set> AsofJoinNode::supported_on_types_ = {int64()}; -const std::set> AsofJoinNode::supported_by_types_ = {int32()}; -const std::set> AsofJoinNode::supported_data_types_ = { +const std::set> AsofJoinNode::kSupportedOnTypes_ = {int64()}; +const std::set> AsofJoinNode::kSupportedByTypes_ = {int32()}; +const std::set> AsofJoinNode::kSupportedDataTypes_ = { int32(), int64(), float32(), float64()}; namespace internal { diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index cd2262a4a74..8b993764abe 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -39,30 +39,6 @@ using testing::UnorderedElementsAreArray; namespace arrow { namespace compute { -BatchesWithSchema GenerateBatchesFromString( - const std::shared_ptr& schema, - const std::vector& json_strings, int multiplicity = 1) { - BatchesWithSchema out_batches{{}, schema}; - - std::vector descrs; - for (auto&& field : schema->fields()) { - descrs.emplace_back(field->type()); - } - - for (auto&& s : json_strings) { - out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); - } - - size_t batch_count = out_batches.batches.size(); - for (int repeat = 1; repeat < multiplicity; ++repeat) { - for (size_t i = 0; i < batch_count; ++i) { - out_batches.batches.push_back(out_batches.batches[i]); - } - } - - return out_batches; -} - void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, const BatchesWithSchema& r1_batches, @@ -119,18 +95,18 @@ void DoRunBasicTest(const std::vector& l_data, // Test three table join BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; - l_batches = 
GenerateBatchesFromString(l_schema, l_data); - r0_batches = GenerateBatchesFromString(r0_schema, r0_data); - r1_batches = GenerateBatchesFromString(r1_schema, r1_data); - exp_batches = GenerateBatchesFromString(exp_schema, exp_data); + l_batches = MakeBatchesFromString(l_schema, l_data); + r0_batches = MakeBatchesFromString(r0_schema, r0_data); + r1_batches = MakeBatchesFromString(r1_schema, r1_data); + exp_batches = MakeBatchesFromString(exp_schema, exp_data); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", tolerance); } void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - BatchesWithSchema l_batches = GenerateBatchesFromString(l_schema, {R"([])"}); - BatchesWithSchema r_batches = GenerateBatchesFromString(r_schema, {R"([])"}); + BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); + BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); @@ -318,5 +294,17 @@ TEST(AsofJoinTest, TestUnsupportedDatatype) { schema({field("time", int64()), field("key", int32()), field("r0_v0", utf8())})); } +TEST(AsofJoinTest, TestMissingKeys) { + DoRunInvalidTypeTest( + schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}), + schema( + {field("time1", int64()), field("key", int32()), field("r0_v0", float64())})); + + DoRunInvalidTypeTest( + schema({field("time", int64()), field("key1", int32()), field("l_v0", float64())}), + schema( + {field("time", int64()), field("key1", int32()), field("r0_v0", float64())})); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index add612f5d47..355b9083b03 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -373,16 +373,20 @@ class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { 
AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance) : on_key(std::move(on_key)), by_key(std::move(by_key)), tolerance(tolerance) {} - // "on" key for the join. Each input table must be sorted by the "on" key. Inexact - // match is used on the "on" key. i.e., a row is considiered match iff - // left_on - tolerance <= right_on <= left_on. - // Currently, "on" key must be an int64 field + /// \brief "on" key for the join. + /// + /// All input tables must be sorted by the "on" key. Inexact + /// match is used on the "on" key. i.e., a row is considered a match iff + /// left_on - tolerance <= right_on <= left_on. + /// Currently, "on" key must be an int64 field FieldRef on_key; - // "by" key for the join. All tables must have the "by" key. Exact equality - // is used for the "by" key. - // Currently, the "by" key must be an int32 field + /// \brief "by" key for the join. + /// + /// All input tables must have the "by" key. Exact equality + /// is used for the "by" key. 
+ /// Currently, the "by" key must be an int32 field FieldRef by_key; - // tolerance for inexact "on" key matching + /// Tolerance for inexact "on" key matching int64_t tolerance; }; diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 41eb401ced6..2eabc1d1c26 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -221,6 +221,30 @@ BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, return out; } +BatchesWithSchema MakeBatchesFromString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity) { + BatchesWithSchema out_batches{{}, schema}; + + std::vector descrs; + for (auto&& field : schema->fields()) { + descrs.emplace_back(field->type()); + } + + for (auto&& s : json_strings) { + out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + } + + size_t batch_count = out_batches.batches.size(); + for (int repeat = 1; repeat < multiplicity; ++repeat) { + for (size_t i = 0; i < batch_count; ++i) { + out_batches.batches.push_back(out_batches.batches[i]); + } + } + + return out_batches; +} + Result> SortTableOnAllFields(const std::shared_ptr& tab) { std::vector sort_keys; for (auto&& f : tab->schema()->fields()) { diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 9347d1343f1..c69f1b94446 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -96,6 +96,11 @@ ARROW_TESTING_EXPORT BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, int num_batches = 10, int batch_size = 4); +ARROW_TESTING_EXPORT +BatchesWithSchema MakeBatchesFromString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity = 1); + ARROW_TESTING_EXPORT Result> SortTableOnAllFields(const std::shared_ptr
& tab); From b33effa241c0285056479f400ced0f6147b68648 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 21 Jun 2022 14:58:33 -0400 Subject: [PATCH 44/47] Remove debug statement --- cpp/src/arrow/compute/exec/asof_join_node.cc | 38 ++------------------ 1 file changed, 3 insertions(+), 35 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index d7e20f5a5a4..4d473276d18 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -16,10 +16,8 @@ // under the License. #include -#include #include #include -#include #include #include @@ -387,7 +385,6 @@ class CompositeReferenceTable { Result> Materialize( const std::shared_ptr& output_schema, const std::vector>& state) { - // cerr << "materialize BEGIN\n"; DCHECK_EQ(state.size(), n_tables_); // Don't build empty batches @@ -557,11 +554,8 @@ class AsofJoinNode : public ExecNode { } void Process() { - std::cerr << "process() begin\n"; - std::lock_guard guard(gate_); if (finished_.is_finished()) { - std::cerr << "InputReceived EARLYEND\n"; return; } @@ -582,8 +576,6 @@ class AsofJoinNode : public ExecNode { } } - std::cerr << "process() end\n"; - // Report to the output the total batch count, if we've already finished everything // (there are two places where this can happen: here and InputFinished) // @@ -596,10 +588,8 @@ class AsofJoinNode : public ExecNode { } void ProcessThread() { - std::cerr << "AsofJoinNode::process_thread started.\n"; for (;;) { if (!process_.Pop()) { - std::cerr << "AsofJoinNode::process_thread done.\n"; return; } Process(); @@ -689,28 +679,19 @@ class AsofJoinNode : public ExecNode { // Get the input ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); - std::cerr << "InputReceived BEGIN (k=" << k << ")\n"; // Put into the queue auto rb = 
*batch.ToRecordBatch(input->output_schema()); - std::stringstream stream; - stream << "(k=" << k << ") time=(" << *rb->columns()[0] << ")\n"; - std::cerr << stream.str(); - state_.at(k)->Push(rb); process_.Push(true); - - std::cerr << "InputReceived END\n"; } void ErrorReceived(ExecNode* input, Status error) override { outputs_[0]->ErrorReceived(this, std::move(error)); StopProducing(); } void InputFinished(ExecNode* input, int total_batches) override { - std::cerr << "InputFinished BEGIN\n"; { std::lock_guard guard(gate_); - std::cerr << "InputFinished find\n"; ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); state_.at(k)->set_total_batches(total_batches); @@ -720,31 +701,18 @@ class AsofJoinNode : public ExecNode { // know whether the RHS of the join is up-to-date until we know that the table is // finished. process_.Push(true); - - std::cerr << "InputFinished END\n"; } Status StartProducing() override { - std::cout << "StartProducing" - << "\n"; finished_ = arrow::Future<>::Make(); return Status::OK(); } - void PauseProducing(ExecNode* output, int32_t counter) override { - std::cout << "PauseProducing" - << "\n"; - } - void ResumeProducing(ExecNode* output, int32_t counter) override { - std::cout << "ResumeProducing" - << "\n"; - } + void PauseProducing(ExecNode* output, int32_t counter) override {} + void ResumeProducing(ExecNode* output, int32_t counter) override {} void StopProducing(ExecNode* output) override { DCHECK_EQ(output, outputs_[0]); StopProducing(); } - void StopProducing() override { - std::cerr << "StopProducing" << std::endl; - finished_.MarkFinished(); - } + void StopProducing() override { finished_.MarkFinished(); } arrow::Future<> finished() override { return finished_; } private: From be417c3b1dc995d59f2a59d00cec1edb56c77338 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 21 Jun 2022 15:02:45 -0400 Subject: [PATCH 45/47] Fix minor 
comment --- cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 4d473276d18..09214e3a265 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -435,7 +435,7 @@ class CompositeReferenceTable { } // Build the result - DCHECK_GE(sizeof(size_t), sizeof(int64_t)) << "Requires size_t >= 8 bytes"; + DCHECK_LE(n_rows, std::numeric_limits::max()); std::shared_ptr r = arrow::RecordBatch::Make(output_schema, (int64_t)n_rows, arrays); return r; From 1261e20e8a5ad0b00b669c0327f05ab8aa74ec41 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 21 Jun 2022 15:58:21 -0400 Subject: [PATCH 46/47] Fix lint --- cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 09214e3a265..24c118cd4dc 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -435,7 +435,7 @@ class CompositeReferenceTable { } // Build the result - DCHECK_LE(n_rows, std::numeric_limits::max()); + DCHECK_LE(n_rows, std::numeric_limits::max()); std::shared_ptr r = arrow::RecordBatch::Make(output_schema, (int64_t)n_rows, arrays); return r; From 26fae6d5c104f8180532b6254a030b7e7c97e5c0 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 21 Jun 2022 17:14:56 -0400 Subject: [PATCH 47/47] Fix lint again --- cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 24c118cd4dc..93eca8dbfb6 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -435,7 +435,7 @@ class CompositeReferenceTable { } // Build the 
result - DCHECK_LE(n_rows, std::numeric_limits::max()); + DCHECK_LE(n_rows, (uint64_t)std::numeric_limits::max()); std::shared_ptr r = arrow::RecordBatch::Make(output_schema, (int64_t)n_rows, arrays); return r;