diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 1fc6db642e0..87349191e90 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -26,6 +26,7 @@
 #include "arrow/compute/api_vector.h"
 #include "arrow/compute/exec.h"
 #include "arrow/compute/exec/expression.h"
+#include "arrow/util/async_util.h"
 #include "arrow/util/optional.h"
 #include "arrow/util/visibility.h"

@@ -110,10 +111,12 @@ class ARROW_EXPORT AggregateNodeOptions : public ExecNodeOptions {
 /// Emitted batches will not be ordered.
 class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
  public:
-  explicit SinkNodeOptions(std::function<Future<util::optional<ExecBatch>>()>* generator)
-      : generator(generator) {}
+  explicit SinkNodeOptions(std::function<Future<util::optional<ExecBatch>>()>* generator,
+                           util::BackpressureOptions backpressure = {})
+      : generator(generator), backpressure(std::move(backpressure)) {}

   std::function<Future<util::optional<ExecBatch>>()>* generator;
+  util::BackpressureOptions backpressure;
 };

 class ARROW_EXPORT SinkNodeConsumer {
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index ac81ed93904..c4ec36490f1 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -237,6 +237,56 @@ TEST(ExecPlanExecution, SourceSink) {
   }
 }

+TEST(ExecPlanExecution, SinkNodeBackpressure) {
+  constexpr uint32_t kPauseIfAbove = 4;
+  constexpr uint32_t kResumeIfBelow = 2;
+  EXPECT_OK_AND_ASSIGN(std::shared_ptr<ExecPlan> plan, ExecPlan::Make());
+  PushGenerator<util::optional<ExecBatch>> batch_producer;
+  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  util::BackpressureOptions backpressure_options =
+      util::BackpressureOptions::Make(kResumeIfBelow, kPauseIfAbove);
+  std::shared_ptr<Schema> schema_ = schema({field("data", uint32())});
+  ARROW_EXPECT_OK(compute::Declaration::Sequence(
+                      {
+                          {"source", SourceNodeOptions(schema_, batch_producer)},
+                          {"sink", SinkNodeOptions{&sink_gen, backpressure_options}},
+                      })
+                      .AddToPlan(plan.get()));
+  ARROW_EXPECT_OK(plan->StartProducing());
+
+  EXPECT_OK_AND_ASSIGN(util::optional<ExecBatch> batch,
+                       ExecBatch::Make({MakeScalar(0)}));
+  ASSERT_TRUE(backpressure_options.toggle->IsOpen());
+
+  // Should be able to push kPauseIfAbove batches without triggering backpressure
+  for (uint32_t i = 0; i < kPauseIfAbove; i++) {
+    batch_producer.producer().Push(batch);
+  }
+  SleepABit();
+  ASSERT_TRUE(backpressure_options.toggle->IsOpen());
+
+  // One more batch should trigger backpressure
+  batch_producer.producer().Push(batch);
+  BusyWait(10, [&] { return !backpressure_options.toggle->IsOpen(); });
+  ASSERT_FALSE(backpressure_options.toggle->IsOpen());
+
+  // Read as much as we can while keeping backpressure applied
+  for (uint32_t i = kPauseIfAbove; i >= kResumeIfBelow; i--) {
+    ASSERT_FINISHES_OK(sink_gen());
+  }
+  SleepABit();
+  ASSERT_FALSE(backpressure_options.toggle->IsOpen());
+
+  // Reading one more item should release backpressure
+  ASSERT_FINISHES_OK(sink_gen());
+  BusyWait(10, [&] { return backpressure_options.toggle->IsOpen(); });
+  ASSERT_TRUE(backpressure_options.toggle->IsOpen());
+
+  // Cleanup
+  batch_producer.producer().Push(IterationEnd<util::optional<ExecBatch>>());
+  plan->StopProducing();
+  ASSERT_FINISHES_OK(plan->finished());
+}
+
 TEST(ExecPlan, ToString) {
   auto basic_data = MakeBasicBatches();
   AsyncGenerator<util::optional<ExecBatch>> sink_gen;
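For context, this is the consumption pattern the test above exercises: the producer pushes into the plan, the sink's queue closes the toggle once it exceeds pause_if_above, and draining the sink generator below resume_if_below re-opens it. A condensed consumer-side sketch (illustrative only; assumes a plan declared as in the test, with error handling elided):

    util::BackpressureOptions backpressure =
        util::BackpressureOptions::Make(/*resume_if_below=*/2, /*pause_if_above=*/4);
    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
    // ... declare a plan ending in {"sink", SinkNodeOptions{&sink_gen, backpressure}} ...
    // Each pull pops one item from the sink queue; once the queue drops below
    // resume_if_below the toggle re-opens and upstream production resumes.
    while (backpressure.toggle && !backpressure.toggle->IsOpen()) {
      sink_gen().Wait();
    }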
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index e1c10edd8a9..1bb2680383c 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -49,22 +49,25 @@ namespace {
 class SinkNode : public ExecNode {
  public:
   SinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
-           AsyncGenerator<util::optional<ExecBatch>>* generator)
+           AsyncGenerator<util::optional<ExecBatch>>* generator,
+           util::BackpressureOptions backpressure)
       : ExecNode(plan, std::move(inputs), {"collected"}, {},
                  /*num_outputs=*/0),
-        producer_(MakeProducer(generator)) {}
+        producer_(MakeProducer(generator, std::move(backpressure))) {}

   static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
                                 const ExecNodeOptions& options) {
     RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode"));

     const auto& sink_options = checked_cast<const SinkNodeOptions&>(options);
-    return plan->EmplaceNode<SinkNode>(plan, std::move(inputs), sink_options.generator);
+    return plan->EmplaceNode<SinkNode>(plan, std::move(inputs), sink_options.generator,
+                                       sink_options.backpressure);
   }

   static PushGenerator<util::optional<ExecBatch>>::Producer MakeProducer(
-      AsyncGenerator<util::optional<ExecBatch>>* out_gen) {
-    PushGenerator<util::optional<ExecBatch>> push_gen;
+      AsyncGenerator<util::optional<ExecBatch>>* out_gen,
+      util::BackpressureOptions backpressure) {
+    PushGenerator<util::optional<ExecBatch>> push_gen(std::move(backpressure));
     auto out = push_gen.producer();
     *out_gen = std::move(push_gen);
     return out;
@@ -234,8 +237,10 @@ class ConsumingSinkNode : public ExecNode {
 struct OrderBySinkNode final : public SinkNode {
   OrderBySinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
                   std::unique_ptr<OrderByImpl> impl,
-                  AsyncGenerator<util::optional<ExecBatch>>* generator)
-      : SinkNode(plan, std::move(inputs), generator), impl_{std::move(impl)} {}
+                  AsyncGenerator<util::optional<ExecBatch>>* generator,
+                  util::BackpressureOptions backpressure)
+      : SinkNode(plan, std::move(inputs), generator, std::move(backpressure)),
+        impl_{std::move(impl)} {}

   const char* kind_name() const override { return "OrderBySinkNode"; }

@@ -250,7 +255,8 @@ struct OrderBySinkNode final : public SinkNode {
                           OrderByImpl::MakeSort(plan->exec_context(),
                                                 inputs[0]->output_schema(),
                                                 sink_options.sort_options));
     return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(impl),
-                                              sink_options.generator);
+                                              sink_options.generator,
+                                              sink_options.backpressure);
   }

   // A sink node that receives inputs and then compute top_k/bottom_k.
@@ -264,7 +270,8 @@ struct OrderBySinkNode final : public SinkNode {
                           OrderByImpl::MakeSelectK(plan->exec_context(),
                                                    inputs[0]->output_schema(),
                                                    sink_options.select_k_options));
     return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(impl),
-                                              sink_options.generator);
+                                              sink_options.generator,
+                                              sink_options.backpressure);
   }

   void InputReceived(ExecNode* input, ExecBatch batch) override {
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 433e93172c9..81dc3e55072 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -598,20 +598,23 @@ Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync(
   ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context.get()));
   AsyncGenerator<util::optional<ExecBatch>> sink_gen;

+  util::BackpressureOptions backpressure =
+      util::BackpressureOptions::Make(kDefaultBackpressureLow, kDefaultBackpressureHigh);
   auto exprs = scan_options_->projection.call()->arguments;
   auto names = checked_cast<const compute::MakeStructOptions*>(
                    scan_options_->projection.call()->options.get())
                    ->field_names;
-  RETURN_NOT_OK(compute::Declaration::Sequence(
-                    {
-                        {"scan", ScanNodeOptions{dataset_, scan_options_}},
-                        {"filter", compute::FilterNodeOptions{scan_options_->filter}},
-                        {"augmented_project",
-                         compute::ProjectNodeOptions{std::move(exprs), std::move(names)}},
-                        {"sink", compute::SinkNodeOptions{&sink_gen}},
-                    })
-                    .AddToPlan(plan.get()));
+  RETURN_NOT_OK(
+      compute::Declaration::Sequence(
+          {
+              {"scan", ScanNodeOptions{dataset_, scan_options_, backpressure.toggle}},
+              {"filter", compute::FilterNodeOptions{scan_options_->filter}},
+              {"augmented_project",
+               compute::ProjectNodeOptions{std::move(exprs), std::move(names)}},
+              {"sink", compute::SinkNodeOptions{&sink_gen, std::move(backpressure)}},
+          })
+          .AddToPlan(plan.get()));

   RETURN_NOT_OK(plan->StartProducing());

@@ -1139,6 +1142,7 @@ Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan* plan,
   const auto& scan_node_options = checked_cast<const ScanNodeOptions&>(options);
   auto scan_options = scan_node_options.scan_options;
   auto dataset = scan_node_options.dataset;
+  const auto& backpressure_toggle = scan_node_options.backpressure_toggle;

   if (!scan_options->use_async) {
     return Status::NotImplemented("ScanNodes without asynchrony");
@@ -1201,6 +1205,10 @@
         return batch;
       });

+  if (backpressure_toggle) {
+    gen = MakePauseable(gen, backpressure_toggle);
+  }
+
   auto fields = scan_options->dataset_schema->fields();
   for (const auto& aug_field : kAugmentedFields) {
     fields.push_back(aug_field);
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 9264e9f548a..78746068d87 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -53,6 +53,8 @@ namespace dataset {
 constexpr int64_t kDefaultBatchSize = 1 << 20;
 constexpr int32_t kDefaultBatchReadahead = 32;
 constexpr int32_t kDefaultFragmentReadahead = 8;
+constexpr int32_t kDefaultBackpressureHigh = 64;
+constexpr int32_t kDefaultBackpressureLow = 32;

 /// Scan-specific options, which can be changed between scans of the same dataset.
 struct ARROW_DS_EXPORT ScanOptions {
@@ -417,12 +419,16 @@ class ARROW_DS_EXPORT ScannerBuilder {
 /// ordering for simple ExecPlans.
 class ARROW_DS_EXPORT ScanNodeOptions : public compute::ExecNodeOptions {
  public:
-  explicit ScanNodeOptions(std::shared_ptr<Dataset> dataset,
-                           std::shared_ptr<ScanOptions> scan_options)
-      : dataset(std::move(dataset)), scan_options(std::move(scan_options)) {}
+  explicit ScanNodeOptions(
+      std::shared_ptr<Dataset> dataset, std::shared_ptr<ScanOptions> scan_options,
+      std::shared_ptr<util::AsyncToggle> backpressure_toggle = NULLPTR)
+      : dataset(std::move(dataset)),
+        scan_options(std::move(scan_options)),
+        backpressure_toggle(std::move(backpressure_toggle)) {}

   std::shared_ptr<Dataset> dataset;
   std::shared_ptr<ScanOptions> scan_options;
+  std::shared_ptr<util::AsyncToggle> backpressure_toggle;
 };

 /// @}
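The same wiring can be done by hand when assembling a plan directly; the essential invariant is that the scan node and the sink node share one toggle. A condensed sketch of the pattern used in ScanBatchesUnorderedAsync above, with `dataset`, `scan_options`, and `plan` assumed in scope:

    util::BackpressureOptions backpressure = util::BackpressureOptions::Make(
        kDefaultBackpressureLow, kDefaultBackpressureHigh);
    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
    RETURN_NOT_OK(
        compute::Declaration::Sequence(
            {
                // The scan pauses when the toggle closes...
                {"scan", ScanNodeOptions{dataset, scan_options, backpressure.toggle}},
                // ...and the sink closes/opens it as its queue fills and drains.
                {"sink", compute::SinkNodeOptions{&sink_gen, std::move(backpressure)}},
            })
            .AddToPlan(plan.get()));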
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 6235cf2fd50..40a0e005a3f 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -32,12 +32,14 @@
 #include "arrow/dataset/test_util.h"
 #include "arrow/record_batch.h"
 #include "arrow/table.h"
+#include "arrow/testing/async_test_util.h"
 #include "arrow/testing/future_util.h"
 #include "arrow/testing/generator.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/testing/matchers.h"
 #include "arrow/testing/util.h"
 #include "arrow/util/range.h"
+#include "arrow/util/thread_pool.h"
 #include "arrow/util/vector.h"

 using testing::ElementsAre;
@@ -740,7 +742,9 @@ INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
 class ControlledFragment : public Fragment {
  public:
   explicit ControlledFragment(std::shared_ptr<Schema> schema)
-      : Fragment(literal(true), std::move(schema)) {}
+      : Fragment(literal(true), std::move(schema)),
+        record_batch_generator_(),
+        tracking_generator_(record_batch_generator_) {}

   Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
     return Status::NotImplemented(
@@ -753,9 +757,11 @@ class ControlledFragment : public Fragment {

   Result<RecordBatchGenerator> ScanBatchesAsync(
       const std::shared_ptr<ScanOptions>& options) override {
-    return record_batch_generator_;
+    return tracking_generator_;
   };

+  int NumBatchesRead() { return tracking_generator_.num_read(); }
+
   void Finish() { ARROW_UNUSED(record_batch_generator_.producer().Close()); }
   void DeliverBatch(uint32_t num_rows) {
     auto batch = ConstantArrayGenerator::Zeroes(num_rows, physical_schema_);
@@ -764,6 +770,7 @@ class ControlledFragment : public Fragment {

  private:
   PushGenerator<std::shared_ptr<RecordBatch>> record_batch_generator_;
+  util::TrackingGenerator<std::shared_ptr<RecordBatch>> tracking_generator_;
 };

 // TODO(ARROW-8163) Add testing for fragments arriving out of order
@@ -963,6 +970,99 @@ TEST_F(TestReordering, ScanBatchesUnordered) {
   AssertBatchesInOrder(collected, {0, 0, 1, 1, 2}, {0, 2, 3, 1, 4});
 }

+class TestBackpressure : public ::testing::Test {
+ protected:
+  static constexpr int NFRAGMENTS = 10;
+  static constexpr int NBATCHES = 50;
+  static constexpr int NROWS = 10;
+
+  FragmentVector MakeFragmentsAndDeliverInitialBatches() {
+    FragmentVector fragments;
+    for (int i = 0; i < NFRAGMENTS; i++) {
+      controlled_fragments_.emplace_back(std::make_shared<ControlledFragment>(schema_));
+      fragments.push_back(controlled_fragments_[i]);
+      // We only emit one batch on the first fragment.  This forces the sequencing
+      // generator to dig deep to try and find the second batch.
+      int num_to_emit = NBATCHES;
+      if (i == 0) {
+        num_to_emit = 1;
+      }
+      for (int j = 0; j < num_to_emit; j++) {
+        controlled_fragments_[i]->DeliverBatch(NROWS);
+      }
+    }
+    return fragments;
+  }
+
+  void DeliverAdditionalBatches() {
+    // Deliver a bunch of batches that should not be read in
+    for (int i = 1; i < NFRAGMENTS; i++) {
+      for (int j = 0; j < NBATCHES; j++) {
+        controlled_fragments_[i]->DeliverBatch(NROWS);
+      }
+    }
+  }
+
+  std::shared_ptr<Dataset> MakeDataset() {
+    FragmentVector fragments = MakeFragmentsAndDeliverInitialBatches();
+    return std::make_shared<FragmentDataset>(schema_, std::move(fragments));
+  }
+
+  std::shared_ptr<Scanner> MakeScanner() {
+    std::shared_ptr<Dataset> dataset = MakeDataset();
+    std::shared_ptr<ScanOptions> options = std::make_shared<ScanOptions>();
+    ScannerBuilder builder(std::move(dataset), options);
+    ARROW_EXPECT_OK(builder.UseThreads(true));
+    ARROW_EXPECT_OK(builder.UseAsync(true));
+    ARROW_EXPECT_OK(builder.FragmentReadahead(4));
+    EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+    return scanner;
+  }
+
+  int TotalBatchesRead() {
+    int sum = 0;
+    for (const auto& controlled_fragment : controlled_fragments_) {
+      sum += controlled_fragment->NumBatchesRead();
+    }
+    return sum;
+  }
+
+  void Finish(AsyncGenerator<EnumeratedRecordBatch> gen) {
+    for (const auto& controlled_fragment : controlled_fragments_) {
+      controlled_fragment->Finish();
+    }
+    ASSERT_FINISHES_OK(VisitAsyncGenerator(
+        gen, [](EnumeratedRecordBatch batch) { return Status::OK(); }));
+  }
+
+  std::shared_ptr<Schema> schema_ = schema({field("values", int32())});
+  std::vector<std::shared_ptr<ControlledFragment>> controlled_fragments_;
+};
+
+TEST_F(TestBackpressure, ScanBatchesUnordered) {
+  std::shared_ptr<Scanner> scanner = MakeScanner();
+  EXPECT_OK_AND_ASSIGN(AsyncGenerator<EnumeratedRecordBatch> gen,
+                       scanner->ScanBatchesUnorderedAsync());
+  ASSERT_FINISHES_OK(gen());
+  // The exact numbers may be imprecise due to threading but we should pretty quickly
+  // read up to our backpressure limit and a little above.  We should not be able to
+  // go too far above it.
+  BusyWait(30, [&] { return TotalBatchesRead() >= kDefaultBackpressureHigh; });
+  ASSERT_GE(TotalBatchesRead(), kDefaultBackpressureHigh);
+  // Wait for the thread pool to idle.  By this point the scanner should have paused
+  // itself.  This helps with timing on slower CI systems where there is only one core
+  // and the scanner might keep that core until it has scanned all the batches, which
+  // never gives the sink a chance to report it is falling behind.
+  GetCpuThreadPool()->WaitForIdle();
+  DeliverAdditionalBatches();
+
+  SleepABit();
+  // Worst case we read in the entire set of initial batches
+  ASSERT_LE(TotalBatchesRead(), NBATCHES * (NFRAGMENTS - 1) + 1);
+
+  Finish(std::move(gen));
+}
+
 struct BatchConsumer {
   explicit BatchConsumer(EnumeratedRecordBatchGenerator generator)
       : generator(std::move(generator)), next() {}
diff --git a/cpp/src/arrow/testing/async_test_util.h b/cpp/src/arrow/testing/async_test_util.h
new file mode 100644
index 00000000000..b9f5487ed0d
--- /dev/null
+++ b/cpp/src/arrow/testing/async_test_util.h
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <memory>
+
+#include "arrow/util/async_generator.h"
+#include "arrow/util/future.h"
+
+namespace arrow {
+namespace util {
+
+template <typename T>
+class TrackingGenerator {
+ public:
+  explicit TrackingGenerator(AsyncGenerator<T> source)
+      : state_(std::make_shared<State>(std::move(source))) {}
+
+  Future<T> operator()() {
+    state_->num_read++;
+    return state_->source();
+  }
+
+  int num_read() { return state_->num_read.load(); }
+
+ private:
+  struct State {
+    explicit State(AsyncGenerator<T> source) : source(std::move(source)), num_read(0) {}
+
+    AsyncGenerator<T> source;
+    std::atomic<int> num_read;
+  };
+
+  std::shared_ptr<State> state_;
+};
+
+}  // namespace util
+}  // namespace arrow
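Copies of a TrackingGenerator share state, so the wrapped generator can be handed off while the tracker retains the count. A small usage sketch, borrowing AsyncVectorIt and TestInt from the existing async generator test helpers:

    AsyncGenerator<TestInt> source = AsyncVectorIt<TestInt>({1, 2, 3});
    util::TrackingGenerator<TestInt> tracker(std::move(source));
    AsyncGenerator<TestInt> gen = static_cast<AsyncGenerator<TestInt>>(tracker);
    ASSERT_FINISHES_OK(gen());         // polls the underlying source once
    ASSERT_EQ(1, tracker.num_read());  // the shared counter saw that poll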
diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index e751c7511d3..d2a2339f5bc 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -24,6 +24,7 @@
 #include <limits>
 #include <queue>

+#include "arrow/util/async_util.h"
 #include "arrow/util/functional.h"
 #include "arrow/util/future.h"
 #include "arrow/util/io_util.h"
@@ -785,6 +786,24 @@ class ReadaheadGenerator {
 template <typename T>
 class PushGenerator {
   struct State {
+    explicit State(util::BackpressureOptions backpressure)
+        : backpressure(std::move(backpressure)) {}
+
+    void OpenBackpressureIfFreeUnlocked(util::Mutex::Guard&& guard) {
+      if (backpressure.toggle && result_q.size() < backpressure.resume_if_below) {
+        // Open might trigger callbacks so release the lock first
+        guard.Unlock();
+        backpressure.toggle->Open();
+      }
+    }
+
+    void CloseBackpressureIfFullUnlocked() {
+      if (backpressure.toggle && result_q.size() > backpressure.pause_if_above) {
+        backpressure.toggle->Close();
+      }
+    }
+
+    util::BackpressureOptions backpressure;
     util::Mutex mutex;
     std::deque<Result<T>> result_q;
     util::optional<Future<T>> consumer_fut;
@@ -820,6 +839,7 @@
       fut.MarkFinished(std::move(result));
     } else {
       state->result_q.push_back(std::move(result));
+      state->CloseBackpressureIfFullUnlocked();
     }
     return true;
   }
@@ -868,7 +888,8 @@ class PushGenerator {
     const std::weak_ptr<State> weak_state_;
   };

-  PushGenerator() : state_(std::make_shared<State>()) {}
+  explicit PushGenerator(util::BackpressureOptions backpressure = {})
+      : state_(std::make_shared<State>(std::move(backpressure))) {}

   /// Read an item from the queue
   Future<T> operator()() const {
@@ -877,6 +898,7 @@ class PushGenerator {
     if (!state_->result_q.empty()) {
       auto fut = Future<T>::MakeFinished(std::move(state_->result_q.front()));
       state_->result_q.pop_front();
+      state_->OpenBackpressureIfFreeUnlocked(std::move(lock));
       return fut;
     }
     if (state_->finished) {
@@ -1645,6 +1667,50 @@
   return CancellableGenerator<T>{std::move(source), std::move(stop_token)};
 }

+template <typename T>
+struct PauseableGenerator {
+ public:
+  PauseableGenerator(AsyncGenerator<T> source, std::shared_ptr<util::AsyncToggle> toggle)
+      : state_(std::make_shared<PauseableGeneratorState>(std::move(source),
+                                                         std::move(toggle))) {}
+
+  Future<T> operator()() { return (*state_)(); }
+
+ private:
+  struct PauseableGeneratorState
+      : public std::enable_shared_from_this<PauseableGeneratorState> {
+    PauseableGeneratorState(AsyncGenerator<T> source,
+                            std::shared_ptr<util::AsyncToggle> toggle)
+        : source_(std::move(source)), toggle_(std::move(toggle)) {}
+
+    Future<T> operator()() {
+      std::shared_ptr<PauseableGeneratorState> self = this->shared_from_this();
+      return toggle_->WhenOpen().Then([self] {
+        util::Mutex::Guard guard = self->mutex_.Lock();
+        return self->source_();
+      });
+    }
+
+    AsyncGenerator<T> source_;
+    std::shared_ptr<util::AsyncToggle> toggle_;
+    util::Mutex mutex_;
+  };
+  std::shared_ptr<PauseableGeneratorState> state_;
+};
+
+/// \brief Allows an async generator to be paused
+///
+/// This generator is NOT async-reentrant and calling it in an async-reentrant fashion
+/// may lead to items getting reordered (and potentially truncated if the end token is
+/// reordered ahead of valid items)
+///
+/// This generator forwards async-reentrant pressure
+template <typename T>
+AsyncGenerator<T> MakePauseable(AsyncGenerator<T> source,
+                                std::shared_ptr<util::AsyncToggle> toggle) {
+  return PauseableGenerator<T>(std::move(source), std::move(toggle));
+}
+
 template <typename T>
 class DefaultIfEmptyGenerator {
  public:
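To illustrate the pause semantics: the toggle is consulted only at the start of each pull, so a pull already in flight when Close is called still resolves, and a pull made while closed waits on WhenOpen. A sketch, with `source` assumed to be an existing AsyncGenerator<TestInt>:

    auto toggle = std::make_shared<util::AsyncToggle>();
    AsyncGenerator<TestInt> pauseable = MakePauseable(std::move(source), toggle);
    toggle->Close();
    Future<TestInt> fut = pauseable();  // stays unfinished while the toggle is closed
    toggle->Open();                     // WhenOpen() completes and the pull proceeds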
diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc
index cb269ccb684..22f55d5cb20 100644
--- a/cpp/src/arrow/util/async_generator_test.cc
+++ b/cpp/src/arrow/util/async_generator_test.cc
@@ -25,10 +25,12 @@
 #include <thread>
 #include <unordered_set>

 #include "arrow/io/slow.h"
+#include "arrow/testing/async_test_util.h"
 #include "arrow/testing/future_util.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/async_generator.h"
+#include "arrow/util/async_util.h"
 #include "arrow/util/optional.h"
 #include "arrow/util/test_common.h"
 #include "arrow/util/vector.h"
@@ -72,30 +74,6 @@ AsyncGenerator<TestInt> MakeJittery(AsyncGenerator<TestInt> source) {
   });
 }

-template <typename T>
-class TrackingGenerator {
- public:
-  explicit TrackingGenerator(AsyncGenerator<T> source)
-      : state_(std::make_shared<State>(std::move(source))) {}
-
-  Future<T> operator()() {
-    state_->num_read++;
-    return state_->source();
-  }
-
-  int num_read() { return state_->num_read.load(); }
-
- private:
-  struct State {
-    explicit State(AsyncGenerator<T> source) : source(std::move(source)), num_read(0) {}
-
-    AsyncGenerator<T> source;
-    std::atomic<int> num_read;
-  };
-
-  std::shared_ptr<State> state_;
-};
-
 // Yields items with a small pause between each one from a background thread
 std::function<Future<TestInt>()> BackgroundAsyncVectorIt(
     std::vector<TestInt> v, bool sleep = true, int max_q = kDefaultBackgroundMaxQ,
@@ -233,6 +211,9 @@ ReentrantCheckerGuard<T> ExpectNotAccessedReentrantly(AsyncGenerator<T>* generator) {
 }

 class GeneratorTestFixture : public ::testing::TestWithParam<bool> {
+ public:
+  ~GeneratorTestFixture() override = default;
+
  protected:
   AsyncGenerator<TestInt> MakeSource(const std::vector<int>& items) {
     std::vector<TestInt> wrapped(items.begin(), items.end());
@@ -386,7 +367,7 @@ TEST(TestAsyncUtil, MapAsync) {
 TEST(TestAsyncUtil, MapReentrant) {
   std::vector<TestInt> input = {1, 2};
   auto source = AsyncVectorIt(input);
-  TrackingGenerator<TestInt> tracker(std::move(source));
+  util::TrackingGenerator<TestInt> tracker(std::move(source));
   source = MakeTransferredGenerator(AsyncGenerator<TestInt>(tracker),
                                     internal::GetCpuThreadPool());
@@ -590,7 +571,7 @@ TEST_P(MergedGeneratorTestFixture, MergedLimitedSubscriptions) {
   auto gen = AsyncVectorIt<AsyncGenerator<TestInt>>(
       {MakeSource({1, 2}), MakeSource({3, 4}), MakeSource({5, 6, 7, 8}),
        MakeSource({9, 10, 11, 12})});
-  TrackingGenerator<AsyncGenerator<TestInt>> tracker(std::move(gen));
+  util::TrackingGenerator<AsyncGenerator<TestInt>> tracker(std::move(gen));
   auto merged = MakeMergedGenerator(AsyncGenerator<AsyncGenerator<TestInt>>(tracker), 2);

   SleepABit();
@@ -1263,6 +1244,91 @@ TEST_P(EnumeratorTestFixture, Error) {
 INSTANTIATE_TEST_SUITE_P(EnumeratedTests, EnumeratorTestFixture,
                          ::testing::Values(false, true));
+class PauseableTestFixture : public GeneratorTestFixture {
+ public:
+  ~PauseableTestFixture() override { generator_.producer().Close(); }
+
+ protected:
+  PauseableTestFixture() : toggle_(std::make_shared<util::AsyncToggle>()) {
+    sink_.clear();
+    counter_ = 0;
+    AsyncGenerator<TestInt> source = GetSource();
+    AsyncGenerator<TestInt> pauseable = MakePauseable(std::move(source), toggle_);
+    finished_ = VisitAsyncGenerator(std::move(pauseable), [this](TestInt val) {
+      std::lock_guard<std::mutex> lg(mutex_);
+      sink_.push_back(val.value);
+      return Status::OK();
+    });
+  }
+
+  void Emit() { generator_.producer().Push(counter_++); }
+
+  void Pause() { toggle_->Close(); }
+
+  void Resume() { toggle_->Open(); }
+
+  int NumCollected() {
+    std::lock_guard<std::mutex> lg(mutex_);
+    // The push generator can desequence things so we check for and don't count gaps.
+    // It's a bit inefficient but good enough for this test.
+    int count = 0;
+    for (std::size_t i = 0; i < sink_.size(); i++) {
+      int prev_count = count;
+      for (std::size_t j = 0; j < sink_.size(); j++) {
+        if (sink_[j] == count) {
+          count++;
+          break;
+        }
+      }
+      if (prev_count == count) {
+        break;
+      }
+    }
+    return count;
+  }
+
+  void AssertAtLeastNCollected(int target_count) {
+    BusyWait(10, [this, target_count] { return NumCollected() >= target_count; });
+    ASSERT_GE(NumCollected(), target_count);
+  }
+
+  void AssertNoMoreThanNCollected(int target_count) {
+    ASSERT_LE(NumCollected(), target_count);
+  }
+
+  AsyncGenerator<TestInt> GetSource() {
+    const auto& source = static_cast<AsyncGenerator<TestInt>>(generator_);
+    if (IsSlow()) {
+      return SlowdownABit(source);
+    } else {
+      return source;
+    }
+  }
+
+  std::mutex mutex_;
+  int counter_ = 0;
+  PushGenerator<TestInt> generator_;
+  std::shared_ptr<util::AsyncToggle> toggle_;
+  std::vector<int> sink_;
+  Future<> finished_;
+};
+
+INSTANTIATE_TEST_SUITE_P(PauseableTests, PauseableTestFixture,
+                         ::testing::Values(false, true));
+
+TEST_P(PauseableTestFixture, PauseBasic) {
+  Emit();
+  Pause();
+  // This emit was requested before the pause so it will go through
+  Emit();
+  AssertNoMoreThanNCollected(2);
+  // This emit should be blocked by the pause
+  Emit();
+  AssertNoMoreThanNCollected(2);
+  Resume();
+  AssertAtLeastNCollected(3);
+}
+
 class SequencerTestFixture : public GeneratorTestFixture {
  protected:
   void RandomShuffle(std::vector<TestInt>& values) {
@@ -1361,6 +1427,21 @@ TEST_P(SequencerTestFixture, SequenceError) {
   }
 }

+TEST_P(SequencerTestFixture, Readahead) {
+  AsyncGenerator<TestInt> original = MakeSource({4, 2, 0, 6});
+  util::TrackingGenerator<TestInt> tracker(original);
+  AsyncGenerator<TestInt> sequenced = MakeSequencingGenerator(
+      static_cast<AsyncGenerator<TestInt>>(tracker), cmp_, is_next_, TestInt(-2));
+  ASSERT_FINISHES_OK_AND_EQ(TestInt(0), sequenced());
+  ASSERT_EQ(3, tracker.num_read());
+  ASSERT_FINISHES_OK_AND_EQ(TestInt(2), sequenced());
+  ASSERT_EQ(3, tracker.num_read());
+  ASSERT_FINISHES_OK_AND_EQ(TestInt(4), sequenced());
+  ASSERT_EQ(3, tracker.num_read());
+  ASSERT_FINISHES_OK_AND_EQ(TestInt(6), sequenced());
+  ASSERT_EQ(4, tracker.num_read());
+}
+
 TEST_P(SequencerTestFixture, SequenceStress) {
   constexpr int NITEMS = 100;
   for (auto task_index = 0; task_index < GetNumItersForStress(); task_index++) {
diff --git a/cpp/src/arrow/util/async_util.cc b/cpp/src/arrow/util/async_util.cc
index 9407684bdda..f5b9bdcbe6c 100644
--- a/cpp/src/arrow/util/async_util.cc
+++ b/cpp/src/arrow/util/async_util.cc
@@ -162,5 +162,45 @@ bool SerializedAsyncTaskGroup::TryDrainUnlocked() {
   return false;
 }

+Future<> AsyncToggle::WhenOpen() {
+  util::Mutex::Guard guard = mutex_.Lock();
+  return when_open_;
+}
+
+void AsyncToggle::Open() {
+  util::Mutex::Guard guard = mutex_.Lock();
+  if (!closed_) {
+    return;
+  }
+  closed_ = false;
+  Future<> to_finish = when_open_;
+  guard.Unlock();
+  to_finish.MarkFinished();
+}
+
+void AsyncToggle::Close() {
+  util::Mutex::Guard guard = mutex_.Lock();
+  if (closed_) {
+    return;
+  }
+  closed_ = true;
+  when_open_ = Future<>::Make();
+}
+
+bool AsyncToggle::IsOpen() {
+  util::Mutex::Guard guard = mutex_.Lock();
+  return !closed_;
+}
+
+BackpressureOptions BackpressureOptions::Make(uint32_t resume_if_below,
+                                              uint32_t pause_if_above) {
+  auto toggle = std::make_shared<AsyncToggle>();
+  return BackpressureOptions{std::move(toggle), resume_if_below, pause_if_above};
+}
+
+BackpressureOptions BackpressureOptions::NoBackpressure() {
+  return BackpressureOptions();
+}
+
 }  // namespace util
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/async_util.h b/cpp/src/arrow/util/async_util.h
index daa6bad8cee..29b21683099 100644
--- a/cpp/src/arrow/util/async_util.h
+++ b/cpp/src/arrow/util/async_util.h
@@ -195,5 +195,64 @@ class ARROW_EXPORT SerializedAsyncTaskGroup {
   Future<> processing_;
 };

+class ARROW_EXPORT AsyncToggle {
+ public:
+  /// \brief Get a future that will complete when the toggle next becomes open
+  ///
+  /// If the toggle is open this returns an already finished future.
+  /// If the toggle is closed this future will be unfinished until the next call to Open.
+  Future<> WhenOpen();
+  /// \brief Close the toggle
+  ///
+  /// After this call any call to WhenOpen will be delayed until the next open
+  void Close();
+  /// \brief Open the toggle
+  ///
+  /// Note: This call may complete a future, triggering any callbacks, and generally
+  /// should not be done while holding any locks.
+  ///
+  /// Note: If Open is called from multiple threads it could lead to a situation where
+  /// callbacks from the second open finish before callbacks on the first open.
+  ///
+  /// All current waiters will be released to enter, even if another close call
+  /// quickly follows
+  void Open();
+
+  /// \brief Return true if the toggle is currently open
+  bool IsOpen();
+
+ private:
+  Future<> when_open_ = Future<>::MakeFinished();
+  bool closed_ = false;
+  util::Mutex mutex_;
+};
+
+/// \brief Options to control backpressure behavior
+struct ARROW_EXPORT BackpressureOptions {
+  /// \brief Create default options that perform no backpressure
+  BackpressureOptions() : toggle(NULLPTR), resume_if_below(0), pause_if_above(0) {}
+  /// \brief Create options that will perform backpressure
+  ///
+  /// \param toggle A toggle to be shared between the producer and consumer
+  /// \param resume_if_below The producer should resume producing if the backpressure
+  ///        queue has fewer than resume_if_below items.
+  /// \param pause_if_above The producer should pause producing if the backpressure
+  ///        queue has more than pause_if_above items.
+  BackpressureOptions(std::shared_ptr<AsyncToggle> toggle, uint32_t resume_if_below,
+                      uint32_t pause_if_above)
+      : toggle(std::move(toggle)),
+        resume_if_below(resume_if_below),
+        pause_if_above(pause_if_above) {}
+
+  static BackpressureOptions Make(uint32_t resume_if_below = 32,
+                                  uint32_t pause_if_above = 64);
+
+  static BackpressureOptions NoBackpressure();
+
+  std::shared_ptr<AsyncToggle> toggle;
+  uint32_t resume_if_below;
+  uint32_t pause_if_above;
+};
+
 }  // namespace util
 }  // namespace arrow
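The toggle itself is just a gate over a swapped future: it starts open, each Close installs a fresh unfinished future, and the next Open completes it, releasing every waiter at once. In miniature:

    util::AsyncToggle toggle;           // open by default
    Future<> gate = toggle.WhenOpen();  // already finished
    toggle.Close();
    gate = toggle.WhenOpen();           // unfinished until the next Open
    toggle.Open();                      // finishes the gate, releasing all waiters

diff --git a/cpp/src/arrow/util/thread_pool.cc b/cpp/src/arrow/util/thread_pool.cc
index 758295d01ed..37132fe1a9c 100644
--- a/cpp/src/arrow/util/thread_pool.cc
+++ b/cpp/src/arrow/util/thread_pool.cc
@@ -121,6 +121,7 @@ struct ThreadPool::State {
   std::mutex mutex_;
   std::condition_variable cv_;
   std::condition_variable cv_shutdown_;
+  std::condition_variable cv_idle_;

   std::list<std::thread> workers_;
   // Trashcan for finished threads
@@ -182,7 +183,9 @@ static void WorkerLoop(std::shared_ptr<ThreadPool::State> state,
         ARROW_UNUSED(std::move(task));  // release resources before waiting for lock
         lock.lock();
       }
-      state->tasks_queued_or_running_--;
+      if (ARROW_PREDICT_FALSE(--state->tasks_queued_or_running_ == 0)) {
+        state->cv_idle_.notify_all();
+      }
     }
     // Now either the queue is empty *or* a quick shutdown was requested
     if (state->please_shutdown_ || should_secede()) {
@@ -209,6 +212,11 @@ static void WorkerLoop(std::shared_ptr<ThreadPool::State> state,
   }
 }

+void ThreadPool::WaitForIdle() {
+  std::unique_lock<std::mutex> lk(state_->mutex_);
+  state_->cv_idle_.wait(lk, [this] { return state_->tasks_queued_or_running_ == 0; });
+}
+
 ThreadPool::ThreadPool()
     : sp_state_(std::make_shared<ThreadPool::State>()),
       state_(sp_state_.get()),
diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h
index 9ac8e36a3d8..4ed908d6f29 100644
--- a/cpp/src/arrow/util/thread_pool.h
+++ b/cpp/src/arrow/util/thread_pool.h
@@ -341,6 +341,11 @@ class ARROW_EXPORT ThreadPool : public Executor {
   // tasks are finished.
   Status Shutdown(bool wait = true);

+  // Wait for the thread pool to become idle
+  //
+  // This is useful for sequencing tests
+  void WaitForIdle();
+
   struct State;

 protected:

WaitForIdle is what lets the backpressure test above sequence itself against the CPU pool. A typical use, assuming arrow::internal::GetCpuThreadPool as in scanner_test.cc:

    ASSERT_OK(::arrow::internal::GetCpuThreadPool()->Spawn([] { /* queued work */ }));
    ::arrow::internal::GetCpuThreadPool()->WaitForIdle();
    // Returns only once tasks_queued_or_running_ has dropped to zero, i.e. no
    // task is queued or executing; new tasks may of course be submitted after.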