diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 86f14de46fd..8437c75ae1c 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -360,6 +360,64 @@ class WriteQueue { std::shared_ptr schema_; }; +struct WriteState { + explicit WriteState(FileSystemDatasetWriteOptions write_options) + : write_options(std::move(write_options)) {} + + FileSystemDatasetWriteOptions write_options; + util::Mutex mutex; + std::unordered_map> queues; +}; + +Status WriteNextBatch(WriteState& state, const std::shared_ptr& scan_task, + std::shared_ptr batch) { + ARROW_ASSIGN_OR_RAISE(auto groups, state.write_options.partitioning->Partition(batch)); + batch.reset(); // drop to hopefully conserve memory + + if (groups.batches.size() > static_cast(state.write_options.max_partitions)) { + return Status::Invalid("Fragment would be written into ", groups.batches.size(), + " partitions. This exceeds the maximum of ", + state.write_options.max_partitions); + } + + std::unordered_set need_flushed; + for (size_t i = 0; i < groups.batches.size(); ++i) { + auto partition_expression = and_(std::move(groups.expressions[i]), + scan_task->fragment()->partition_expression()); + auto batch = std::move(groups.batches[i]); + + ARROW_ASSIGN_OR_RAISE(auto part, + state.write_options.partitioning->Format(partition_expression)); + + WriteQueue* queue; + { + // lookup the queue to which batch should be appended + auto queues_lock = state.mutex.Lock(); + + queue = internal::GetOrInsertGenerated( + &state.queues, std::move(part), + [&](const std::string& emplaced_part) { + // lookup in `queues` also failed, + // generate a new WriteQueue + size_t queue_index = state.queues.size() - 1; + + return internal::make_unique(emplaced_part, queue_index, + batch->schema()); + }) + ->second.get(); + } + + queue->Push(std::move(batch)); + need_flushed.insert(queue); + } + + // flush all touched WriteQueues + for (auto queue : need_flushed) { + RETURN_NOT_OK(queue->Flush(state.write_options)); + } + return Status::OK(); +} + } // namespace Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options, @@ -382,6 +440,7 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio ARROW_ASSIGN_OR_RAISE(auto fragment_it, scanner->GetFragments()); ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, fragment_it.ToVector()); ScanTaskVector scan_tasks; + std::vector> scan_futs; for (const auto& fragment : fragments) { auto options = std::make_shared(*scanner->options()); @@ -399,68 +458,35 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio // to a WriteQueue which flushes batches into that partition's output file. In principle // any thread could produce a batch for any partition, so each task alternates between // pushing batches and flushing them to disk. - util::Mutex queues_mutex; - std::unordered_map> queues; + WriteState state(write_options); for (const auto& scan_task : scan_tasks) { - task_group->Append([&, scan_task] { - ARROW_ASSIGN_OR_RAISE(auto batches, scan_task->Execute()); - - for (auto maybe_batch : batches) { - ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch); - ARROW_ASSIGN_OR_RAISE(auto groups, write_options.partitioning->Partition(batch)); - batch.reset(); // drop to hopefully conserve memory - - if (groups.batches.size() > static_cast(write_options.max_partitions)) { - return Status::Invalid("Fragment would be written into ", groups.batches.size(), - " partitions. This exceeds the maximum of ", - write_options.max_partitions); - } - - std::unordered_set need_flushed; - for (size_t i = 0; i < groups.batches.size(); ++i) { - auto partition_expression = and_(std::move(groups.expressions[i]), - scan_task->fragment()->partition_expression()); - auto batch = std::move(groups.batches[i]); - - ARROW_ASSIGN_OR_RAISE(auto part, - write_options.partitioning->Format(partition_expression)); - - WriteQueue* queue; - { - // lookup the queue to which batch should be appended - auto queues_lock = queues_mutex.Lock(); - - queue = internal::GetOrInsertGenerated( - &queues, std::move(part), - [&](const std::string& emplaced_part) { - // lookup in `queues` also failed, - // generate a new WriteQueue - size_t queue_index = queues.size() - 1; - - return internal::make_unique( - emplaced_part, queue_index, batch->schema()); - }) - ->second.get(); - } - - queue->Push(std::move(batch)); - need_flushed.insert(queue); - } + if (scan_task->supports_async()) { + ARROW_ASSIGN_OR_RAISE(auto batches_gen, scan_task->ExecuteAsync()); + std::function batch)> batch_visitor = + [&, scan_task](std::shared_ptr batch) { + return WriteNextBatch(state, scan_task, std::move(batch)); + }; + scan_futs.push_back(VisitAsyncGenerator(batches_gen, batch_visitor)); + } else { + task_group->Append([&, scan_task] { + ARROW_ASSIGN_OR_RAISE(auto batches, scan_task->Execute()); - // flush all touched WriteQueues - for (auto queue : need_flushed) { - RETURN_NOT_OK(queue->Flush(write_options)); + for (auto maybe_batch : batches) { + ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch); + RETURN_NOT_OK(WriteNextBatch(state, scan_task, std::move(batch))); } - } - return Status::OK(); - }); + return Status::OK(); + }); + } } RETURN_NOT_OK(task_group->Finish()); + auto scan_futs_all_done = AllComplete(scan_futs); + RETURN_NOT_OK(scan_futs_all_done.status()); task_group = scanner->options()->TaskGroup(); - for (const auto& part_queue : queues) { + for (const auto& part_queue : state.queues) { task_group->Append([&] { return part_queue.second->writer()->Finish(); }); } return task_group->Finish(); diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 9c613c00aff..e4e7167aa75 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -285,7 +285,7 @@ class ARROW_DS_EXPORT FileWriter { Status Write(RecordBatchReader* batches); - Status Finish(); + virtual Status Finish(); const std::shared_ptr& format() const { return options_->format(); } const std::shared_ptr& schema() const { return schema_; } diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index e736d06753b..b55c23dfdef 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -34,6 +34,7 @@ #include "arrow/io/compressed.h" #include "arrow/result.h" #include "arrow/type.h" +#include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" @@ -42,6 +43,7 @@ namespace dataset { using internal::checked_cast; using internal::checked_pointer_cast; +using RecordBatchGenerator = AsyncGenerator>; Result> GetColumnNames( const csv::ParseOptions& parse_options, util::string_view first_block, @@ -110,35 +112,47 @@ static inline Result GetReadOptions( return read_options; } -static inline Result> OpenReader( +static inline Future> OpenReaderAsync( const FileSource& source, const CsvFileFormat& format, const std::shared_ptr& scan_options = nullptr, MemoryPool* pool = default_memory_pool()) { ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options)); - util::string_view first_block; ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed()); ARROW_ASSIGN_OR_RAISE( input, io::BufferedInputStream::Create(reader_options.block_size, default_memory_pool(), std::move(input))); - ARROW_ASSIGN_OR_RAISE(first_block, input->Peek(reader_options.block_size)); - - const auto& parse_options = format.parse_options; - auto convert_options = csv::ConvertOptions::Defaults(); - if (scan_options != nullptr) { - ARROW_ASSIGN_OR_RAISE(convert_options, - GetConvertOptions(format, scan_options, first_block, pool)); - } - auto maybe_reader = - csv::StreamingReader::Make(io::IOContext(pool), std::move(input), reader_options, - parse_options, convert_options); - if (!maybe_reader.ok()) { - return maybe_reader.status().WithMessage("Could not open CSV input source '", - source.path(), "': ", maybe_reader.status()); - } + auto peek_fut = DeferNotOk(input->io_context().executor()->Submit( + [input, reader_options] { return input->Peek(reader_options.block_size); })); + + return peek_fut.Then([=](const util::string_view& first_block) + -> Future> { + const auto& parse_options = format.parse_options; + auto convert_options = csv::ConvertOptions::Defaults(); + if (scan_options != nullptr) { + ARROW_ASSIGN_OR_RAISE(convert_options, + GetConvertOptions(format, scan_options, first_block, pool)); + } + + return csv::StreamingReader::MakeAsync(io::default_io_context(), std::move(input), + reader_options, parse_options, convert_options) + .Then( + [](const std::shared_ptr& maybe_reader) + -> Result> { return maybe_reader; }, + [source](const Status& err) -> Result> { + return err.WithMessage("Could not open CSV input source '", source.path(), + "': ", err); + }); + }); +} - return std::move(maybe_reader).ValueOrDie(); +static inline Result> OpenReader( + const FileSource& source, const CsvFileFormat& format, + const std::shared_ptr& scan_options = nullptr, + MemoryPool* pool = default_memory_pool()) { + auto open_reader_fut = OpenReaderAsync(source, format, scan_options, pool); + return open_reader_fut.result(); } /// \brief A ScanTask backed by an Csv file. @@ -152,9 +166,19 @@ class CsvScanTask : public ScanTask { source_(fragment->source()) {} Result Execute() override { - ARROW_ASSIGN_OR_RAISE(auto reader, - OpenReader(source_, *format_, options(), options()->pool)); - return IteratorFromReader(std::move(reader)); + ARROW_ASSIGN_OR_RAISE(auto gen, ExecuteAsync()); + return MakeGeneratorIterator(std::move(gen)); + } + + bool supports_async() const override { return true; } + + Result ExecuteAsync() override { + auto reader_fut = OpenReaderAsync(source_, *format_, options(), options()->pool); + auto generator_fut = reader_fut.Then( + [](const std::shared_ptr& reader) -> RecordBatchGenerator { + return [reader]() { return reader->ReadNextAsync(); }; + }); + return MakeFromFuture(generator_fut); } private: diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index c7ce5154d0a..fdbb4512758 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -249,6 +249,35 @@ TEST_F(TestFileSystemDataset, FragmentPartitions) { }); } +class TestFilesystemDatasetNestedParallelism : public NestedParallelismMixin {}; + +TEST_F(TestFilesystemDatasetNestedParallelism, Write) { + constexpr int NUM_BATCHES = 32; + RecordBatchVector batches; + for (int i = 0; i < NUM_BATCHES; i++) { + batches.push_back(ConstantArrayGenerator::Zeroes(/*size=*/1, schema_)); + } + auto dataset = std::make_shared(schema_, std::move(batches)); + ScannerBuilder builder{dataset, options_}; + ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); + + ASSERT_OK_AND_ASSIGN(auto output_dir, TemporaryDir::Make("nested-parallel-dataset")); + + auto format = std::make_shared(); + auto rows_written = std::make_shared>(0); + std::shared_ptr file_write_options = + std::make_shared(rows_written); + FileSystemDatasetWriteOptions dataset_write_options; + dataset_write_options.file_write_options = file_write_options; + dataset_write_options.basename_template = "{i}"; + dataset_write_options.partitioning = std::make_shared(schema({})); + dataset_write_options.base_dir = output_dir->path().ToString(); + dataset_write_options.filesystem = std::make_shared(); + + ASSERT_OK(FileSystemDataset::Write(dataset_write_options, scanner)); + ASSERT_EQ(NUM_BATCHES, rows_written->load()); +} + // Tests of subtree pruning struct TestPathTree { diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index dee96ceb836..2258a10d141 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -61,6 +61,12 @@ Result InMemoryScanTask::Execute() { return MakeVectorIterator(record_batches_); } +Result ScanTask::ExecuteAsync() { + return Status::NotImplemented("Async is not implemented for this scan task yet"); +} + +bool ScanTask::supports_async() const { return false; } + Result Scanner::GetFragments() { if (fragment_ != nullptr) { return MakeVectorIterator(FragmentVector{fragment_}); @@ -203,19 +209,31 @@ Result> Scanner::ToTable() { auto state = std::make_shared(); size_t scan_task_id = 0; + std::vector> scan_futures; for (auto maybe_scan_task : scan_task_it) { ARROW_ASSIGN_OR_RAISE(auto scan_task, maybe_scan_task); auto id = scan_task_id++; - task_group->Append([state, id, scan_task] { - ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute()); - ARROW_ASSIGN_OR_RAISE(auto local, batch_it.ToVector()); - state->Emplace(std::move(local), id); - return Status::OK(); - }); + if (scan_task->supports_async()) { + ARROW_ASSIGN_OR_RAISE(auto scan_gen, scan_task->ExecuteAsync()); + auto scan_fut = CollectAsyncGenerator(std::move(scan_gen)) + .Then([state, id](const RecordBatchVector& rbs) { + state->Emplace(rbs, id); + }); + scan_futures.push_back(std::move(scan_fut)); + } else { + task_group->Append([state, id, scan_task] { + ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute()); + ARROW_ASSIGN_OR_RAISE(auto local, batch_it.ToVector()); + state->Emplace(std::move(local), id); + return Status::OK(); + }); + } } + // Wait for all async tasks to complete, or the first error + RETURN_NOT_OK(AllComplete(scan_futures).status()); - // Wait for all tasks to complete, or the first error. + // Wait for all sync tasks to complete, or the first error. RETURN_NOT_OK(task_group->Finish()); return Table::FromRecordBatches(scan_options_->projected_schema, diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index df5f7954afe..c3cce00d8c5 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -31,9 +31,11 @@ #include "arrow/dataset/visibility.h" #include "arrow/memory_pool.h" #include "arrow/type_fwd.h" +#include "arrow/util/async_generator.h" #include "arrow/util/type_fwd.h" namespace arrow { +using RecordBatchGenerator = AsyncGenerator>; namespace dataset { constexpr int64_t kDefaultBatchSize = 1 << 20; @@ -101,6 +103,8 @@ class ARROW_DS_EXPORT ScanTask { /// resulting from the Scan. Execution semantics are encapsulated in the /// particular ScanTask implementation virtual Result Execute() = 0; + virtual Result ExecuteAsync(); + virtual bool supports_async() const; virtual ~ScanTask() = default; diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index e666d251cd1..3101be477fd 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -36,6 +36,8 @@ using internal::checked_cast; namespace dataset { +// TODO(ARROW-7001) This synchronous version is no longer needed, can use async version +// regardless of sync/async of source inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression filter, MemoryPool* pool) { return MakeMaybeMapIterator( @@ -60,6 +62,38 @@ inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression std::move(it)); } +inline Result> DoFilterRecordBatch( + const Expression& filter, MemoryPool* pool, const std::shared_ptr& in) { + compute::ExecContext exec_context{pool}; + ARROW_ASSIGN_OR_RAISE(Datum mask, + ExecuteScalarExpression(filter, Datum(in), &exec_context)); + + if (mask.is_scalar()) { + const auto& mask_scalar = mask.scalar_as(); + if (mask_scalar.is_valid && mask_scalar.value) { + return std::move(in); + } + return in->Slice(0, 0); + } + + ARROW_ASSIGN_OR_RAISE( + Datum filtered, + compute::Filter(in, mask, compute::FilterOptions::Defaults(), &exec_context)); + return filtered.record_batch(); +} + +inline RecordBatchGenerator FilterRecordBatch(RecordBatchGenerator rbs, Expression filter, + MemoryPool* pool) { + // TODO(ARROW-7001) This changes to auto + std::function>(const std::shared_ptr&)> + mapper = [=](const std::shared_ptr& in) { + return DoFilterRecordBatch(filter, pool, in); + }; + return MakeMappedGenerator(std::move(rbs), mapper); +} + +// TODO(ARROW-7001) This synchronous version is no longer needed, all branches use async +// version inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, Expression projection, MemoryPool* pool) { return MakeMaybeMapIterator( @@ -83,6 +117,35 @@ inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, std::move(it)); } +inline Result> DoProjectRecordBatch( + const Expression& projection, MemoryPool* pool, + const std::shared_ptr& in) { + compute::ExecContext exec_context{pool}; + ARROW_ASSIGN_OR_RAISE(Datum projected, + ExecuteScalarExpression(projection, Datum(in), &exec_context)); + DCHECK_EQ(projected.type()->id(), Type::STRUCT); + if (projected.shape() == ValueDescr::SCALAR) { + // Only virtual columns are projected. Broadcast to an array + ARROW_ASSIGN_OR_RAISE(projected, + MakeArrayFromScalar(*projected.scalar(), in->num_rows(), pool)); + } + + ARROW_ASSIGN_OR_RAISE(auto out, + RecordBatch::FromStructArray(projected.array_as())); + + return out->ReplaceSchemaMetadata(in->schema()->metadata()); +} + +inline RecordBatchGenerator ProjectRecordBatch(RecordBatchGenerator rbs, + Expression projection, MemoryPool* pool) { + // TODO(ARROW-7001) This changes to auto + std::function>(const std::shared_ptr&)> + mapper = [=](const std::shared_ptr& in) { + return DoProjectRecordBatch(projection, pool, in); + }; + return MakeMappedGenerator(std::move(rbs), mapper); +} + class FilterAndProjectScanTask : public ScanTask { public: explicit FilterAndProjectScanTask(std::shared_ptr task, Expression partition) @@ -90,7 +153,9 @@ class FilterAndProjectScanTask : public ScanTask { task_(std::move(task)), partition_(std::move(partition)) {} - Result Execute() override { + bool supports_async() const override { return task_->supports_async(); } + + Result ExecuteSync() { ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute()); ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, @@ -106,6 +171,36 @@ class FilterAndProjectScanTask : public ScanTask { options_->pool); } + Result Execute() override { + if (task_->supports_async()) { + ARROW_ASSIGN_OR_RAISE(auto gen, ExecuteAsync()); + return MakeGeneratorIterator(std::move(gen)); + } else { + return ExecuteSync(); + } + } + + Result ExecuteAsync() override { + if (!task_->supports_async()) { + return Status::Invalid( + "ExecuteAsync should not have been called on FilterAndProjectScanTask if the " + "source task did not support async"); + } + ARROW_ASSIGN_OR_RAISE(auto gen, task_->ExecuteAsync()); + + ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, + SimplifyWithGuarantee(options()->filter, partition_)); + + ARROW_ASSIGN_OR_RAISE(Expression simplified_projection, + SimplifyWithGuarantee(options()->projection, partition_)); + + RecordBatchGenerator filter_gen = + FilterRecordBatch(std::move(gen), simplified_filter, options_->pool); + + return ProjectRecordBatch(std::move(filter_gen), simplified_projection, + options_->pool); + } + private: std::shared_ptr task_; Expression partition_; diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 66b1edff568..eec8ed21668 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -151,6 +151,21 @@ TEST_F(TestScanner, ToTable) { AssertTablesEqual(*expected, *actual); } +class TestScannerNestedParallelism : public NestedParallelismMixin {}; + +TEST_F(TestScannerNestedParallelism, Scan) { + constexpr int NUM_BATCHES = 32; + RecordBatchVector batches; + for (int i = 0; i < NUM_BATCHES; i++) { + batches.push_back(ConstantArrayGenerator::Zeroes(/*size=*/1, schema_)); + } + auto dataset = std::make_shared(schema_, std::move(batches)); + ScannerBuilder builder{dataset, options_}; + ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable()); + ASSERT_EQ(table->num_rows(), NUM_BATCHES); +} + class TestScannerBuilder : public ::testing::Test { void SetUp() override { DatasetVector sources; diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 6a4c1eb8d13..86bb14b038d 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -780,5 +780,156 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { std::shared_ptr scan_options_; }; +// These test cases will run on a thread pool with 1 thread. Any illegal (non-async) +// nested parallelism should deadlock the test +class NestedParallelismMixin : public ::testing::Test { + protected: + static void SetUpTestSuite() {} + + void TearDown() override { + if (old_capacity_ > 0) { + ASSERT_OK(internal::GetCpuThreadPool()->SetCapacity(old_capacity_)); + } + } + + void SetUp() override { + old_capacity_ = internal::GetCpuThreadPool()->GetCapacity(); + ASSERT_OK(internal::GetCpuThreadPool()->SetCapacity(1)); + schema_ = schema({field("i32", int32())}); + options_ = std::make_shared(); + options_->dataset_schema = schema_; + options_->use_threads = true; + } + + class NestedParallelismScanTask : public ScanTask { + public: + explicit NestedParallelismScanTask(std::shared_ptr target) + : ScanTask(target->options(), target->fragment()), target_(std::move(target)) {} + virtual ~NestedParallelismScanTask() = default; + + Result Execute() override { + // We could just return an invalid status here but this way it is easy to verify the + // test is checking what it is supposed to be checking by just changing + // supports_async() to false (will deadlock) + ADD_FAILURE() << "NestedParallelismScanTask::Execute should never be called. You " + "should be deadlocked right now"; + ARROW_ASSIGN_OR_RAISE(auto batch_gen, ExecuteAsync()); + return MakeGeneratorIterator(std::move(batch_gen)); + } + + Result ExecuteAsync() override { + ARROW_ASSIGN_OR_RAISE(auto batches_it, target_->Execute()); + ARROW_ASSIGN_OR_RAISE(auto batches, batches_it.ToVector()); + auto generator_fut = DeferNotOk(internal::GetCpuThreadPool()->Submit( + [batches] { return MakeVectorGenerator(batches); })); + return MakeFromFuture(generator_fut); + } + + bool supports_async() const override { return true; } + + private: + std::shared_ptr target_; + }; + + class NestedParallelismFragment : public InMemoryFragment { + public: + explicit NestedParallelismFragment(RecordBatchVector record_batches, + Expression expr = literal(true)) + : InMemoryFragment(std::move(record_batches), std::move(expr)) {} + + Result Scan(std::shared_ptr options) override { + ARROW_ASSIGN_OR_RAISE(auto scan_task_it, InMemoryFragment::Scan(options)); + return MakeMaybeMapIterator( + [](std::shared_ptr task) -> Result> { + return std::make_shared(std::move(task)); + }, + std::move(scan_task_it)); + } + }; + + class NestedParallelismDataset : public InMemoryDataset { + public: + NestedParallelismDataset(std::shared_ptr sch, RecordBatchVector batches) + : InMemoryDataset(std::move(sch), std::move(batches)) {} + + protected: + Result GetFragmentsImpl(Expression) override { + auto schema = this->schema(); + + auto create_fragment = + [schema]( + std::shared_ptr batch) -> Result> { + RecordBatchVector batches{batch}; + return std::make_shared(std::move(batches)); + }; + + return MakeMaybeMapIterator(std::move(create_fragment), get_batches_->Get()); + } + }; + + class DiscardingRowCountingFileWriteOptions : public FileWriteOptions { + public: + explicit DiscardingRowCountingFileWriteOptions( + std::shared_ptr> row_counter) + : FileWriteOptions( + std::make_shared(std::move(row_counter))) {} + }; + + class DiscardingRowCountingFileWriter : public FileWriter { + public: + explicit DiscardingRowCountingFileWriter(std::shared_ptr> row_count) + : FileWriter(NULL, NULL, NULL), row_count_(std::move(row_count)) {} + virtual ~DiscardingRowCountingFileWriter() = default; + + Status Write(const std::shared_ptr& batch) override { + row_count_->fetch_add(static_cast(batch->num_rows())); + return Status::OK(); + } + Status Finish() override { return Status::OK(); }; + + protected: + Status FinishInternal() override { return Status::OK(); }; + + private: + std::shared_ptr> row_count_; + }; + + class DiscardingRowCountingFormat : public FileFormat { + public: + DiscardingRowCountingFormat() : row_count_(std::make_shared>(0)) {} + explicit DiscardingRowCountingFormat(std::shared_ptr> row_count) + : row_count_(std::move(row_count)) {} + virtual ~DiscardingRowCountingFormat() = default; + + std::string type_name() const override { return "discarding-row-counting"; } + bool Equals(const FileFormat& other) const override { return true; } + Result IsSupported(const FileSource& source) const override { + return Status::NotImplemented("Should not be called"); + } + Result> Inspect(const FileSource& source) const override { + return Status::NotImplemented("Should not be called"); + } + Result ScanFile( + std::shared_ptr options, + const std::shared_ptr& file) const override { + return Status::NotImplemented("Should not be called"); + } + Result> MakeWriter( + std::shared_ptr destination, std::shared_ptr schema, + std::shared_ptr options) const override { + return std::make_shared(row_count_); + } + std::shared_ptr DefaultWriteOptions() override { return NULLPTR; } + + private: + std::shared_ptr> row_count_; + }; + + protected: + int old_capacity_ = 0; + std::shared_ptr schema_; + std::shared_ptr options_; +}; + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index fc58c3d180b..a08b9e366f0 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -640,6 +640,40 @@ class SerialReadaheadGenerator { std::shared_ptr state_; }; +template +class FutureFirstGenerator { + public: + explicit FutureFirstGenerator(Future> future) + : state_(std::make_shared(std::move(future))) {} + + Future operator()() { + if (state_->source_) { + return state_->source_(); + } else { + auto state = state_; + return state_->future_.Then([state](const AsyncGenerator& source) { + state->source_ = source; + return state->source_(); + }); + } + } + + private: + struct State { + explicit State(Future> future) : future_(future), source_() {} + + Future> future_; + AsyncGenerator source_; + }; + + std::shared_ptr state_; +}; + +template +AsyncGenerator MakeFromFuture(Future> future) { + return FutureFirstGenerator(std::move(future)); +} + /// \brief Creates a generator that will pull from the source into a queue. Unlike /// MakeReadaheadGenerator this will not pull reentrantly from the source. /// diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 2cc14c5f16d..3f2d63f89d6 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -270,6 +270,7 @@ test_that("IPC/Feather format data", { }) test_that("CSV dataset", { + skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-12181 ds <- open_dataset(csv_dir, partitioning = "part", format = "csv") expect_is(ds$format, "CsvFileFormat") expect_is(ds$filesystem, "LocalFileSystem")