diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index db98243267b..a95efb2365b 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -1063,6 +1063,80 @@ AsyncGenerator MakeConcatenatedGenerator(AsyncGenerator> so return MergedGenerator(std::move(source), 1); } +template +struct Enumerated { + T value; + int index; + bool last; +}; + +template +struct IterationTraits> { + static Enumerated End() { return Enumerated{IterationEnd(), -1, false}; } + static bool IsEnd(const Enumerated& val) { return val.index < 0; } +}; + +/// \see MakeEnumeratedGenerator +template +class EnumeratingGenerator { + public: + EnumeratingGenerator(AsyncGenerator source, T initial_value) + : state_(std::make_shared(std::move(source), std::move(initial_value))) {} + + Future> operator()() { + if (state_->finished) { + return AsyncGeneratorEnd>(); + } else { + auto state = state_; + return state->source().Then([state](const T& next) { + auto finished = IsIterationEnd(next); + auto prev = Enumerated{state->prev_value, state->prev_index, finished}; + state->prev_value = next; + state->prev_index++; + state->finished = finished; + return prev; + }); + } + } + + private: + struct State { + State(AsyncGenerator source, T initial_value) + : source(std::move(source)), prev_value(std::move(initial_value)), prev_index(0) { + finished = IsIterationEnd(prev_value); + } + + AsyncGenerator source; + T prev_value; + int prev_index; + bool finished; + }; + + std::shared_ptr state_; +}; + +/// Wraps items from a source generator with positional information +/// +/// When used with MakeMergedGenerator and MakeSequencingGenerator this allows items to be +/// processed in a "first-available" fashion and later resequenced which can reduce the +/// impact of sources with erratic performance (e.g. a filesystem where some items may +/// take longer to read than others). +/// +/// TODO(ARROW-12371) Would require this generator be async-reentrant +/// +/// \see MakeSequencingGenerator for an example of putting items back in order +/// +/// This generator is not async-reentrant +/// +/// This generator buffers one item (so it knows which item is the last item) +template +AsyncGenerator> MakeEnumeratedGenerator(AsyncGenerator source) { + return FutureFirstGenerator>( + source().Then([source](const T& initial_value) -> AsyncGenerator> { + return EnumeratingGenerator(std::move(source), initial_value); + })); +} + /// \see MakeTransferredGenerator template class TransferringGenerator { diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc index 518422de586..62ba35e2f7e 100644 --- a/cpp/src/arrow/util/async_generator_test.cc +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -229,6 +229,8 @@ class GeneratorTestFixture : public ::testing::TestWithParam { return gen; } + AsyncGenerator MakeEmptySource() { return MakeSource({}); } + AsyncGenerator MakeFailingSource() { AsyncGenerator gen = [] { return Future::MakeFinished(Status::Invalid("XYZ")); @@ -1017,6 +1019,50 @@ TEST(TestAsyncUtil, ReadaheadFailed) { ASSERT_TRUE(IsIterationEnd(definitely_last)); } +class EnumeratorTestFixture : public GeneratorTestFixture { + protected: + void AssertEnumeratedCorrectly(AsyncGenerator>& gen, + int num_items) { + auto collected = CollectAsyncGenerator(gen); + ASSERT_FINISHES_OK_AND_ASSIGN(auto items, collected); + EXPECT_EQ(num_items, items.size()); + + for (const auto& item : items) { + ASSERT_EQ(item.index, item.value.value); + bool last = item.index == num_items - 1; + ASSERT_EQ(last, item.last); + } + AssertGeneratorExhausted(gen); + } +}; + +TEST_P(EnumeratorTestFixture, Basic) { + constexpr int NITEMS = 100; + + auto source = MakeSource(RangeVector(NITEMS)); + auto enumerated = MakeEnumeratedGenerator(std::move(source)); + + AssertEnumeratedCorrectly(enumerated, NITEMS); +} + +TEST_P(EnumeratorTestFixture, Empty) { + auto source = MakeEmptySource(); + auto enumerated = MakeEnumeratedGenerator(std::move(source)); + AssertGeneratorExhausted(enumerated); +} + +TEST_P(EnumeratorTestFixture, Error) { + auto source = FailsAt(MakeSource({1, 2, 3}), 1); + auto enumerated = MakeEnumeratedGenerator(std::move(source)); + + // Even though the first item finishes ok the enumerator buffers it. The error then + // takes priority over the buffered result. + ASSERT_FINISHES_AND_RAISES(Invalid, enumerated()); +} + +INSTANTIATE_TEST_SUITE_P(EnumeratedTests, EnumeratorTestFixture, + ::testing::Values(false, true)); + class SequencerTestFixture : public GeneratorTestFixture { protected: void RandomShuffle(std::vector& values) {