Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cpp/src/arrow/util/async_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ class SerialReadaheadGenerator {
std::shared_ptr<State> state_;
};

/// \see MakeFromFuture
template <typename T>
class FutureFirstGenerator {
public:
Expand Down Expand Up @@ -669,6 +670,12 @@ class FutureFirstGenerator {
std::shared_ptr<State> state_;
};

/// \brief Transforms a Future<AsyncGenerator<T>> into an AsyncGenerator<T>
/// that waits for the future to complete as part of the first item.
///
/// This generator is not async-reentrant (even if the generator yielded by future is)
///
/// This generator does not queue
template <typename T>
AsyncGenerator<T> MakeFromFuture(Future<AsyncGenerator<T>> future) {
return FutureFirstGenerator<T>(std::move(future));
Expand Down
35 changes: 29 additions & 6 deletions cpp/src/arrow/util/async_generator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,7 @@ class GeneratorTestFixture : public ::testing::TestWithParam<bool> {
AsyncGenerator<TestInt> MakeSource(const std::vector<TestInt>& items) {
std::vector<TestInt> wrapped(items.begin(), items.end());
auto gen = AsyncVectorIt(std::move(wrapped));
bool slow = GetParam();
if (slow) {
if (IsSlow()) {
return SlowdownABit(std::move(gen));
}
return gen;
Expand All @@ -233,22 +232,22 @@ class GeneratorTestFixture : public ::testing::TestWithParam<bool> {
AsyncGenerator<TestInt> gen = [] {
return Future<TestInt>::MakeFinished(Status::Invalid("XYZ"));
};
bool slow = GetParam();
if (slow) {
if (IsSlow()) {
return SlowdownABit(std::move(gen));
}
return gen;
}

int GetNumItersForStress() {
bool slow = GetParam();
// Run fewer trials for the slow case since they take longer
if (slow) {
if (IsSlow()) {
return 10;
} else {
return 100;
}
}

bool IsSlow() { return GetParam(); }
};

template <typename T>
Expand Down Expand Up @@ -461,6 +460,30 @@ TEST(TestAsyncUtil, Concatenated) {
AssertAsyncGeneratorMatch(expected, concat);
}

class FromFutureFixture : public GeneratorTestFixture {};

TEST_P(FromFutureFixture, Basic) {
auto source = Future<std::vector<TestInt>>::MakeFinished(RangeVector(3));
if (IsSlow()) {
source = SleepABitAsync().Then(
[](...) -> Result<std::vector<TestInt>> { return RangeVector(3); });
}
auto slow = IsSlow();
auto to_gen = source.Then([slow](const std::vector<TestInt>& vec) {
auto vec_gen = MakeVectorGenerator(vec);
if (slow) {
return SlowdownABit(std::move(vec_gen));
}
return vec_gen;
});
auto gen = MakeFromFuture(std::move(to_gen));
auto collected = CollectAsyncGenerator(std::move(gen));
ASSERT_FINISHES_OK_AND_EQ(RangeVector(3), collected);
}

INSTANTIATE_TEST_SUITE_P(FromFutureTests, FromFutureFixture,
::testing::Values(false, true));

class MergedGeneratorTestFixture : public GeneratorTestFixture {};

TEST_P(MergedGeneratorTestFixture, Merged) {
Expand Down