diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 00bc9af7ea4..22250d2f993 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -616,6 +616,105 @@ Result<std::vector<ExecBatch>> DeclarationToExecBatches(Declaration declaration,
   return DeclarationToExecBatchesAsync(std::move(declaration), exec_context).result();
 }
 
+namespace {
+struct BatchConverter {
+  explicit BatchConverter(::arrow::internal::Executor* executor)
+      : exec_context(std::make_shared<ExecContext>(default_memory_pool(), executor)) {}
+
+  ~BatchConverter() {
+    if (!exec_plan) {
+      return;
+    }
+    if (exec_plan->finished().is_finished()) {
+      return;
+    }
+    exec_plan->StopProducing();
+    Status abandoned_status = exec_plan->finished().status();
+    if (!abandoned_status.ok()) {
+      abandoned_status.Warn();
+    }
+  }
+
+  Future<std::shared_ptr<RecordBatch>> operator()() {
+    return exec_batch_gen().Then(
+        [this](const std::optional<ExecBatch>& batch)
+            -> Future<std::shared_ptr<RecordBatch>> {
+          if (batch) {
+            return batch->ToRecordBatch(schema);
+          } else {
+            return exec_plan->finished().Then(
+                []() -> std::shared_ptr<RecordBatch> { return nullptr; });
+          }
+        },
+        [this](const Status& err) {
+          return exec_plan->finished().Then(
+              [err]() -> Result<std::shared_ptr<RecordBatch>> { return err; });
+        });
+  }
+
+  std::shared_ptr<ExecContext> exec_context;
+  AsyncGenerator<std::optional<ExecBatch>> exec_batch_gen;
+  std::shared_ptr<Schema> schema;
+  std::shared_ptr<ExecPlan> exec_plan;
+};
+
+Result<AsyncGenerator<std::shared_ptr<RecordBatch>>> DeclarationToRecordBatchGenerator(
+    Declaration declaration, ::arrow::internal::Executor* executor,
+    std::shared_ptr<Schema>* out_schema) {
+  auto converter = std::make_shared<BatchConverter>(executor);
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> plan,
+                        ExecPlan::Make(converter->exec_context.get()));
+  Declaration with_sink = Declaration::Sequence(
+      {declaration,
+       {"sink", SinkNodeOptions(&converter->exec_batch_gen, &converter->schema)}});
+  ARROW_RETURN_NOT_OK(with_sink.AddToPlan(plan.get()));
+  ARROW_RETURN_NOT_OK(plan->StartProducing());
+  converter->exec_plan = std::move(plan);
+  *out_schema = converter->schema;
+  return [conv = std::move(converter)] { return (*conv)(); };
+}
+}  // namespace
+
+Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(Declaration declaration,
+                                                               bool use_threads) {
+  std::shared_ptr<Schema> schema;
+  auto batch_iterator = std::make_unique<Iterator<std::shared_ptr<RecordBatch>>>(
+      ::arrow::internal::IterateSynchronously<std::shared_ptr<RecordBatch>>(
+          [&](::arrow::internal::Executor* executor)
+              -> Result<AsyncGenerator<std::shared_ptr<RecordBatch>>> {
+            return DeclarationToRecordBatchGenerator(declaration, executor, &schema);
+          },
+          use_threads));
+
+  struct PlanReader : RecordBatchReader {
+    PlanReader(std::shared_ptr<Schema> schema,
+               std::unique_ptr<Iterator<std::shared_ptr<RecordBatch>>> iterator)
+        : schema_(std::move(schema)), iterator_(std::move(iterator)) {}
+
+    std::shared_ptr<Schema> schema() const override { return schema_; }
+
+    Status ReadNext(std::shared_ptr<RecordBatch>* record_batch) override {
+      DCHECK(!!iterator_) << "call to ReadNext on already closed reader";
+      return iterator_->Next().Value(record_batch);
+    }
+
+    Status Close() override {
+      // End plan and read from generator until finished
+      std::shared_ptr<RecordBatch> batch;
+      do {
+        ARROW_RETURN_NOT_OK(ReadNext(&batch));
+      } while (batch != nullptr);
+      iterator_.reset();
+      return Status::OK();
+    }
+
+    std::shared_ptr<Schema> schema_;
+    std::unique_ptr<Iterator<std::shared_ptr<RecordBatch>>> iterator_;
+  };
+
+  return std::make_unique<PlanReader>(std::move(schema), std::move(batch_iterator));
+}
+
 namespace internal {
 
 void RegisterSourceNode(ExecFactoryRegistry*);
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 5d929aa3057..44cd1acf875 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -516,6 +516,10 @@ ARROW_EXPORT Result<std::vector<std::shared_ptr<RecordBatch>>> DeclarationToBatches(
 ARROW_EXPORT Future<std::vector<std::shared_ptr<RecordBatch>>> DeclarationToBatchesAsync(
     Declaration declaration, ExecContext* exec_context = default_exec_context());
 
+/// \brief Utility method to run a declaration and return results as a RecordBatchReader
+ARROW_EXPORT Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(
+    Declaration declaration, bool use_threads);
+
 /// \brief Wrap an ExecBatch generator in a RecordBatchReader.
 ///
 /// The RecordBatchReader does not impose any ordering on emitted batches.
diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc
index 11d3f2050b1..0dd1a0b9a90 100644
--- a/cpp/src/arrow/compute/exec/expression.cc
+++ b/cpp/src/arrow/compute/exec/expression.cc
@@ -1213,6 +1213,27 @@ Result<Expression> SimplifyWithGuarantee(Expression expr,
   return expr;
 }
 
+Result<Expression> RemoveNamedRefs(Expression src) {
+  if (!src.IsBound()) {
+    return Status::Invalid("RemoveNamedRefs called on unbound expression");
+  }
+  return ModifyExpression(
+      std::move(src),
+      /*pre=*/
+      [](Expression expr) {
+        const Expression::Parameter* param = expr.parameter();
+        if (param && !param->ref.IsFieldPath()) {
+          FieldPath ref_as_path(
+              std::vector<int>(param->indices.begin(), param->indices.end()));
+          return Expression(
+              Expression::Parameter{std::move(ref_as_path), param->type, param->indices});
+        }
+
+        return expr;
+      },
+      /*post_call=*/[](Expression expr, ...) { return expr; });
+}
+
 // Serialization is accomplished by converting expressions to KeyValueMetadata and storing
 // this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its
 // columns. Finally, the RecordBatch is written to an IPC file.
diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h
index 7aeb135c994..c9c7b0e605f 100644
--- a/cpp/src/arrow/compute/exec/expression.h
+++ b/cpp/src/arrow/compute/exec/expression.h
@@ -220,6 +220,12 @@ ARROW_EXPORT
 Result<Expression> SimplifyWithGuarantee(Expression,
                                          const Expression& guaranteed_true_predicate);
 
+/// Replace all named field refs (e.g. "x" or "x.y") with field paths (e.g. [0] or [1,3])
+///
+/// This isn't usually needed and does not offer any simplification by itself. However,
+/// it can be useful to normalize an expression to paths to make it simpler to work with.
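+///
+/// An illustrative usage sketch (not part of this patch; assumes a schema with
+/// fields "x" and "y"; Bind is required first, since RemoveNamedRefs rejects
+/// unbound expressions):
+///
+///   Expression expr = call("add", {field_ref("x"), field_ref("y")});
+///   ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*schema));
+///   ARROW_ASSIGN_OR_RAISE(expr, RemoveNamedRefs(std::move(expr)));
+///   // expr now refers to its fields by position, e.g. [0] and [1]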
+ARROW_EXPORT Result<Expression> RemoveNamedRefs(Expression expression);
+
 /// @}
 
 // Execution
diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc
index e15963cfcde..6dc48b3be4e 100644
--- a/cpp/src/arrow/compute/exec/expression_test.cc
+++ b/cpp/src/arrow/compute/exec/expression_test.cc
@@ -929,6 +929,24 @@ TEST(Expression, FoldConstantsBoolean) {
   ExpectFoldsTo(or_(whatever, whatever), whatever);
 }
 
+void ExpectRemovesRefsTo(Expression expr, Expression expected,
+                         const Schema& schema = *kBoringSchema) {
+  ASSERT_OK_AND_ASSIGN(expr, expr.Bind(schema));
+  ASSERT_OK_AND_ASSIGN(expected, expected.Bind(schema));
+
+  ASSERT_OK_AND_ASSIGN(auto without_named_refs, RemoveNamedRefs(expr));
+
+  EXPECT_EQ(without_named_refs, expected);
+}
+
+TEST(Expression, RemoveNamedRefs) {
+  ExpectRemovesRefsTo(field_ref("i32"), field_ref(2));
+  ExpectRemovesRefsTo(call("add", {literal(4), field_ref("i32")}),
+                      call("add", {literal(4), field_ref(2)}));
+  auto nested_schema = Schema({field("a", struct_({field("b", int32())}))});
+  ExpectRemovesRefsTo(field_ref({"a", "b"}), field_ref({0, 0}), nested_schema);
+}
+
 TEST(Expression, ExtractKnownFieldValues) {
   struct {
     void operator()(Expression guarantee,
@@ -1364,6 +1382,10 @@ TEST(Expression, SimplifyWithValidityGuarantee) {
       .WithGuarantee(is_null(field_ref("i32")))
       .Expect(literal(false));
 
+  Simplify{{true_unless_null(field_ref("i32"))}}
+      .WithGuarantee(is_null(field_ref("i32")))
+      .Expect(null_literal(boolean()));
+
   Simplify{is_valid(field_ref("i32"))}
       .WithGuarantee(is_valid(field_ref("i32")))
       .Expect(literal(true));
@@ -1379,6 +1401,21 @@ TEST(Expression, SimplifyWithValidityGuarantee) {
   Simplify{true_unless_null(field_ref("i32"))}
       .WithGuarantee(is_valid(field_ref("i32")))
       .Expect(literal(true));
+
+  Simplify{{equal(field_ref("i32"), literal(7))}}
+      .WithGuarantee(is_null(field_ref("i32")))
+      .Expect(null_literal(boolean()));
+
+  auto i32_is_2_or_null =
+      or_(equal(field_ref("i32"), literal(2)), is_null(field_ref("i32")));
+
+  Simplify{i32_is_2_or_null}
+      .WithGuarantee(is_null(field_ref("i32")))
+      .Expect(literal(true));
+
+  Simplify{{greater(field_ref("i32"), literal(7))}}
+      .WithGuarantee(is_null(field_ref("i32")))
+      .Expect(null_literal(boolean()));
 }
 
 TEST(Expression, SimplifyWithComparisonAndNullableCaveat) {
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 8600b113489..4ab5b8a5396 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -214,9 +214,19 @@ struct ARROW_EXPORT BackpressureOptions {
 class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
  public:
   explicit SinkNodeOptions(std::function<Future<std::optional<ExecBatch>>()>* generator,
+                           std::shared_ptr<Schema>* schema,
                            BackpressureOptions backpressure = {},
                            BackpressureMonitor** backpressure_monitor = NULLPTR)
       : generator(generator),
+        schema(schema),
+        backpressure(backpressure),
+        backpressure_monitor(backpressure_monitor) {}
+
+  explicit SinkNodeOptions(std::function<Future<std::optional<ExecBatch>>()>* generator,
+                           BackpressureOptions backpressure = {},
+                           BackpressureMonitor** backpressure_monitor = NULLPTR)
+      : generator(generator),
+        schema(NULLPTR),
         backpressure(std::move(backpressure)),
         backpressure_monitor(backpressure_monitor) {}
 
@@ -226,6 +236,11 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
   /// data from the plan. If this function is not called frequently enough then the sink
   /// node will start to accumulate data and may apply backpressure.
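   ///
   /// A consumption sketch (illustrative, not part of this patch): each call pulls
   /// one batch, and a std::nullopt result signals the end of the stream.
   ///
   ///   Future<std::optional<ExecBatch>> next = (*options.generator)();
   ///   // ...wait on `next`; std::nullopt means the plan is exhausted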
   std::function<Future<std::optional<ExecBatch>>()>* generator;
+  /// \brief A pointer which will be set to the schema of the generated batches
+  ///
+  /// This is optional; if nullptr is passed in, it will be ignored.
+  /// This will be set when the node is added to the plan, before StartProducing is called.
+  std::shared_ptr<Schema>* schema;
   /// \brief Options to control when to apply backpressure
   ///
   /// This is optional, the default is to never apply backpressure. If the plan is not
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 6c8d497a1d6..25ef3d73a8c 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -390,13 +390,14 @@ TEST(ExecPlanExecution, SinkNodeBackpressure) {
   BackpressureMonitor* backpressure_monitor;
   BackpressureOptions backpressure_options(resume_if_below_bytes, pause_if_above_bytes);
   std::shared_ptr<Schema> schema_ = schema({field("data", uint32())});
-  ARROW_EXPECT_OK(compute::Declaration::Sequence(
-                      {
-                          {"source", SourceNodeOptions(schema_, batch_producer)},
-                          {"sink", SinkNodeOptions{&sink_gen, backpressure_options,
-                                                   &backpressure_monitor}},
-                      })
-                      .AddToPlan(plan.get()));
+  ARROW_EXPECT_OK(
+      compute::Declaration::Sequence(
+          {
+              {"source", SourceNodeOptions(schema_, batch_producer)},
+              {"sink", SinkNodeOptions{&sink_gen, /*schema=*/nullptr,
+                                       backpressure_options, &backpressure_monitor}},
+          })
+          .AddToPlan(plan.get()));
   ASSERT_TRUE(backpressure_monitor);
   ARROW_EXPECT_OK(plan->StartProducing());
 
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 2220064e567..f69a2ebfb62 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -91,7 +91,7 @@ class SinkNode : public ExecNode {
  public:
   SinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
            AsyncGenerator<std::optional<ExecBatch>>* generator,
-           BackpressureOptions backpressure,
+           std::shared_ptr<Schema>* schema, BackpressureOptions backpressure,
            BackpressureMonitor** backpressure_monitor_out)
       : ExecNode(plan, std::move(inputs), {"collected"}, {},
                  /*num_outputs=*/0),
@@ -103,6 +103,9 @@ class SinkNode : public ExecNode {
       *backpressure_monitor_out = &backpressure_queue_;
     }
     auto node_destroyed_capture = node_destroyed_;
+    if (schema) {
+      *schema = inputs_[0]->output_schema();
+    }
     *generator = [this, node_destroyed_capture]() -> Future<std::optional<ExecBatch>> {
       if (*node_destroyed_capture) {
         return Status::Invalid(
@@ -126,7 +129,7 @@ class SinkNode : public ExecNode {
     const auto& sink_options = checked_cast<const SinkNodeOptions&>(options);
     RETURN_NOT_OK(ValidateOptions(sink_options));
     return plan->EmplaceNode<SinkNode>(plan, std::move(inputs), sink_options.generator,
-                                       sink_options.backpressure,
+                                       sink_options.schema, sink_options.backpressure,
                                        sink_options.backpressure_monitor);
   }
 
@@ -414,7 +417,8 @@ struct OrderBySinkNode final : public SinkNode {
   OrderBySinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
                   std::unique_ptr<OrderByImpl> impl,
                   AsyncGenerator<std::optional<ExecBatch>>* generator)
-      : SinkNode(plan, std::move(inputs), generator, /*backpressure=*/{},
+      : SinkNode(plan, std::move(inputs), generator, /*schema=*/nullptr,
+                 /*backpressure=*/{},
                  /*backpressure_monitor_out=*/nullptr),
         impl_(std::move(impl)) {}
 
diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index 4e1d7ac20ec..f09a878b511 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -43,12 +43,14 @@ Fragment::Fragment(compute::Expression partition_expression,
     : partition_expression_(std::move(partition_expression)),
       physical_schema_(std::move(physical_schema)) {}
 
-Future<std::shared_ptr<InspectedFragment>> Fragment::InspectFragment() {
+Future<std::shared_ptr<InspectedFragment>> Fragment::InspectFragment(
+    const FragmentScanOptions* format_options, compute::ExecContext* exec_context) {
   return Status::NotImplemented("Inspect fragment");
 }
 
 Future<std::shared_ptr<FragmentScanner>> Fragment::BeginScan(
-    const FragmentScanRequest& request, const InspectedFragment& inspected_fragment) {
+    const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+    const FragmentScanOptions* format_options, compute::ExecContext* exec_context) {
   return Status::NotImplemented("New scan method");
 }
 
@@ -154,7 +156,8 @@ Future<std::optional<int64_t>> InMemoryFragment::CountRows(
   return Future<std::optional<int64_t>>::MakeFinished(total);
 }
 
-Future<std::shared_ptr<InspectedFragment>> InMemoryFragment::InspectFragment() {
+Future<std::shared_ptr<InspectedFragment>> InMemoryFragment::InspectFragment(
+    const FragmentScanOptions* format_options, compute::ExecContext* exec_context) {
   return std::make_shared<InspectedFragment>(physical_schema_->field_names());
 }
 
@@ -180,7 +183,8 @@ class InMemoryFragment::Scanner : public FragmentScanner {
 };
 
 Future<std::shared_ptr<FragmentScanner>> InMemoryFragment::BeginScan(
-    const FragmentScanRequest& request, const InspectedFragment& inspected_fragment) {
+    const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+    const FragmentScanOptions* format_options, compute::ExecContext* exec_context) {
   return Future<std::shared_ptr<FragmentScanner>>::MakeFinished(
       std::make_shared<Scanner>(this));
 }
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 80c46568a7c..80e9e96136a 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -88,7 +88,7 @@ struct ARROW_DS_EXPORT FragmentScanRequest {
   /// before returning the scanned batch.
   std::vector<FragmentSelectionColumn> columns;
   /// \brief Options specific to the format being scanned
-  FragmentScanOptions* format_scan_options;
+  const FragmentScanOptions* format_scan_options;
 };
 
 /// \brief An iterator-like object that can yield batches created from a fragment
@@ -156,11 +156,13 @@ class ARROW_DS_EXPORT Fragment : public std::enable_shared_from_this<Fragment> {
   /// This will be called before a scan and a fragment should attach whatever
   /// information will be needed to figure out an evolution strategy.  This information
   /// will then be passed to the call to BeginScan
-  virtual Future<std::shared_ptr<InspectedFragment>> InspectFragment();
+  virtual Future<std::shared_ptr<InspectedFragment>> InspectFragment(
+      const FragmentScanOptions* format_options, compute::ExecContext* exec_context);
 
   /// \brief Start a scan operation
   virtual Future<std::shared_ptr<FragmentScanner>> BeginScan(
-      const FragmentScanRequest& request, const InspectedFragment& inspected_fragment);
+      const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+      const FragmentScanOptions* format_options, compute::ExecContext* exec_context);
 
   /// \brief Count the number of rows in this fragment matching the filter using metadata
   /// only. That is, this method may perform I/O, but will not load data.
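// An illustrative sketch (not part of this patch) of how a caller drives the
// two-phase protocol above, assuming `fragment`, `request`, `format_options` and
// `exec_context` are already set up:
//
//   fragment->InspectFragment(format_options, exec_context)
//       .Then([=](const std::shared_ptr<InspectedFragment>& inspected) {
//         return fragment->BeginScan(request, *inspected, format_options,
//                                    exec_context);
//       });
//   // BeginScan yields a FragmentScanner; ScanBatch(i) then reads batch i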
@@ -228,10 +230,13 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment {
                                            compute::Expression predicate,
                                            const std::shared_ptr<ScanOptions>& options) override;
 
-  Future<std::shared_ptr<InspectedFragment>> InspectFragment() override;
+  Future<std::shared_ptr<InspectedFragment>> InspectFragment(
+      const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) override;
 
   Future<std::shared_ptr<FragmentScanner>> BeginScan(
-      const FragmentScanRequest& request,
-      const InspectedFragment& inspected_fragment) override;
+      const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+      const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) override;
 
   std::string type_name() const override { return "in-memory"; }
 
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index 5b37f535b26..ffa25332ad9 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -54,6 +54,19 @@ using internal::checked_pointer_cast;
 
 namespace dataset {
 
+FileSource::FileSource(std::shared_ptr<io::RandomAccessFile> file,
+                       Compression::type compression)
+    : custom_open_([=] { return ToResult(file); }),
+      custom_size_(-1),
+      compression_(compression) {
+  Result<int64_t> maybe_size = file->GetSize();
+  if (maybe_size.ok()) {
+    custom_size_ = *maybe_size;
+  } else {
+    custom_open_ = [st = maybe_size.status()] { return st; };
+  }
+}
+
 Result<std::shared_ptr<io::RandomAccessFile>> FileSource::Open() const {
   if (filesystem_) {
     return filesystem_->OpenInputFile(file_info_);
@@ -66,6 +79,16 @@ Result<std::shared_ptr<io::RandomAccessFile>> FileSource::Open() const {
   return custom_open_();
 }
 
+int64_t FileSource::Size() const {
+  if (filesystem_) {
+    return file_info_.size();
+  }
+  if (buffer_) {
+    return buffer_->size();
+  }
+  return custom_size_;
+}
+
 Result<std::shared_ptr<io::InputStream>> FileSource::OpenCompressed(
     std::optional<Compression::type> compression) const {
   ARROW_ASSIGN_OR_RAISE(auto file, Open());
@@ -108,6 +131,18 @@ Future<std::optional<int64_t>> FileFormat::CountRows(
   return Future<std::optional<int64_t>>::MakeFinished(std::nullopt);
 }
 
+Future<std::shared_ptr<InspectedFragment>> FileFormat::InspectFragment(
+    const FileSource& source, const FragmentScanOptions* format_options,
+    compute::ExecContext* exec_context) const {
+  return Status::NotImplemented("This format does not yet support the scan2 node");
+}
+
+Future<std::shared_ptr<FragmentScanner>> FileFormat::BeginScan(
+    const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+    const FragmentScanOptions* format_options, compute::ExecContext* exec_context) const {
+  return Status::NotImplemented("This format does not yet support the scan2 node");
+}
+
 Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment(
     FileSource source, std::shared_ptr<Schema> physical_schema) {
   return MakeFragment(std::move(source), compute::literal(true),
@@ -137,6 +172,26 @@ Result<RecordBatchGenerator> FileFragment::ScanBatchesAsync(
   return format_->ScanBatchesAsync(options, self);
 }
 
+Future<std::shared_ptr<InspectedFragment>> FileFragment::InspectFragment(
+    const FragmentScanOptions* format_options, compute::ExecContext* exec_context) {
+  const FragmentScanOptions* realized_format_options = format_options;
+  if (format_options == nullptr) {
+    realized_format_options = format_->default_fragment_scan_options.get();
+  }
+  return format_->InspectFragment(source_, realized_format_options, exec_context);
+}
+
+Future<std::shared_ptr<FragmentScanner>> FileFragment::BeginScan(
+    const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+    const FragmentScanOptions* format_options, compute::ExecContext* exec_context) {
+  const FragmentScanOptions* realized_format_options = format_options;
+  if (format_options == nullptr) {
+    realized_format_options = format_->default_fragment_scan_options.get();
+  }
+  return format_->BeginScan(request, inspected_fragment, realized_format_options,
+                            exec_context);
+}
+
 Future<std::optional<int64_t>> FileFragment::CountRows(
     compute::Expression predicate, const std::shared_ptr<ScanOptions>& options) {
   ARROW_ASSIGN_OR_RAISE(predicate, compute::SimplifyWithGuarantee(std::move(predicate),
diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h
index 586c58b3f52..dab7510d5b5 100644
--- a/cpp/src/arrow/dataset/file_base.h
+++ b/cpp/src/arrow/dataset/file_base.h
@@ -65,18 +65,25 @@ class ARROW_DS_EXPORT FileSource : public util::EqualityComparable<FileSource> {
       : buffer_(std::move(buffer)), compression_(compression) {}
 
   using CustomOpen = std::function<Result<std::shared_ptr<io::RandomAccessFile>>()>;
-  explicit FileSource(CustomOpen open) : custom_open_(std::move(open)) {}
+  FileSource(CustomOpen open, int64_t size)
+      : custom_open_(std::move(open)), custom_size_(size) {}
 
   using CustomOpenWithCompression =
       std::function<Result<std::shared_ptr<io::RandomAccessFile>>(Compression::type)>;
-  explicit FileSource(CustomOpenWithCompression open_with_compression,
-                      Compression::type compression = Compression::UNCOMPRESSED)
+  FileSource(CustomOpenWithCompression open_with_compression, int64_t size,
+             Compression::type compression = Compression::UNCOMPRESSED)
       : custom_open_(std::bind(std::move(open_with_compression), compression)),
+        custom_size_(size),
+        compression_(compression) {}
+
+  FileSource(std::shared_ptr<io::RandomAccessFile> file, int64_t size,
+             Compression::type compression = Compression::UNCOMPRESSED)
+      : custom_open_([=] { return ToResult(file); }),
+        custom_size_(size),
         compression_(compression) {}
 
   explicit FileSource(std::shared_ptr<io::RandomAccessFile> file,
-                      Compression::type compression = Compression::UNCOMPRESSED)
-      : custom_open_([=] { return ToResult(file); }), compression_(compression) {}
+                      Compression::type compression = Compression::UNCOMPRESSED);
 
   FileSource() : custom_open_(CustomOpen{&InvalidOpen}) {}
 
@@ -108,6 +115,10 @@ class ARROW_DS_EXPORT FileSource : public util::EqualityComparable<FileSource> {
   /// \brief Get a RandomAccessFile which views this file source
   Result<std::shared_ptr<io::RandomAccessFile>> Open() const;
 
+  /// \brief Get the size (in bytes) of the file or buffer
+  /// If the file is compressed this should be the compressed (on-disk) size.
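+  ///
+  /// For sources built from a CustomOpen callable the size cannot be probed, so it
+  /// must be supplied up front. A hypothetical sketch (`my_file` and `my_size` are
+  /// placeholders, not part of this patch):
+  ///
+  ///   FileSource::CustomOpen open = [my_file] { return ToResult(my_file); };
+  ///   FileSource source(std::move(open), /*size=*/my_size);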
+  int64_t Size() const;
+
   /// \brief Get an InputStream which views this file source (and decompresses if needed)
   /// \param[in] compression If nullopt, guess the compression scheme from the
   ///     filename, else decompress with the given codec
@@ -126,6 +137,7 @@ class ARROW_DS_EXPORT FileSource : public util::EqualityComparable<FileSource> {
   std::shared_ptr<fs::FileSystem> filesystem_;
   std::shared_ptr<Buffer> buffer_;
   CustomOpen custom_open_;
+  int64_t custom_size_ = 0;
   Compression::type compression_ = Compression::UNCOMPRESSED;
 };
 
@@ -150,6 +162,11 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> {
   virtual Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const = 0;
 
+  /// \brief Learn what we need about the file before we start scanning it
+  virtual Future<std::shared_ptr<InspectedFragment>> InspectFragment(
+      const FileSource& source, const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) const;
+
   virtual Result<RecordBatchGenerator> ScanBatchesAsync(
       const std::shared_ptr<ScanOptions>& options,
       const std::shared_ptr<FileFragment>& file) const = 0;
@@ -158,6 +175,11 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> {
       const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
       const std::shared_ptr<ScanOptions>& options);
 
+  virtual Future<std::shared_ptr<FragmentScanner>> BeginScan(
+      const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+      const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) const;
+
   /// \brief Open a fragment
   virtual Result<std::shared_ptr<FileFragment>> MakeFragment(
       FileSource source, compute::Expression partition_expression,
@@ -179,6 +201,10 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> {
   virtual std::shared_ptr<FileWriteOptions> DefaultWriteOptions() = 0;
+
+ protected:
+  explicit FileFormat(std::shared_ptr<FragmentScanOptions> default_fragment_scan_options)
+      : default_fragment_scan_options(std::move(default_fragment_scan_options)) {}
 };
 
 /// \brief A Fragment that is stored in a file with a known format
@@ -190,6 +216,13 @@ class ARROW_DS_EXPORT FileFragment : public Fragment,
   Future<std::optional<int64_t>> CountRows(
       compute::Expression predicate,
       const std::shared_ptr<ScanOptions>& options) override;
+  Future<std::shared_ptr<FragmentScanner>> BeginScan(
+      const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+      const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) override;
+  Future<std::shared_ptr<InspectedFragment>> InspectFragment(
+      const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) override;
 
   std::string type_name() const override { return format_->type_name(); }
   std::string ToString() const override { return source_.path(); };
 
diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc
index be963338b40..122e7f79708 100644
--- a/cpp/src/arrow/dataset/file_csv.cc
+++ b/cpp/src/arrow/dataset/file_csv.cc
@@ -38,6 +38,7 @@
 #include "arrow/result.h"
 #include "arrow/type.h"
 #include "arrow/util/async_generator.h"
+#include "arrow/util/bit_util.h"
 #include "arrow/util/iterator.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/tracing_internal.h"
@@ -52,9 +53,98 @@ using internal::SerialExecutor;
 
 namespace dataset {
 
+struct CsvInspectedFragment : public InspectedFragment {
+  CsvInspectedFragment(std::vector<std::string> column_names,
+                       std::shared_ptr<io::InputStream> input_stream, int64_t num_bytes)
+      : InspectedFragment(std::move(column_names)),
+        input_stream(std::move(input_stream)),
+        num_bytes(num_bytes) {}
+
+  // We need to start reading the file in order to figure out the column names and
+  // so we save off the input stream
+  std::shared_ptr<io::InputStream> input_stream;
+  int64_t num_bytes;
+};
+
+class CsvFileScanner : public FragmentScanner {
+ public:
+  CsvFileScanner(std::shared_ptr<csv::StreamingReader> reader, int num_batches,
+                 int64_t best_guess_bytes_per_batch)
+      : reader_(std::move(reader)),
+        num_batches_(num_batches),
+        best_guess_bytes_per_batch_(best_guess_bytes_per_batch) {}
+
+  Future<std::shared_ptr<RecordBatch>> ScanBatch(int batch_number) override {
+    // This should be called in increasing order but let's verify that in case it changes.
+    // It would be easy enough to handle out of order but no need for that complexity at
+    // the moment.
+    DCHECK_EQ(scanned_so_far_++, batch_number);
+    return reader_->ReadNextAsync();
+  }
+
+  int64_t EstimatedDataBytes(int batch_number) override {
+    return best_guess_bytes_per_batch_;
+  }
+
+  int NumBatches() override { return num_batches_; }
+
+  static Result<csv::ConvertOptions> GetConvertOptions(
+      const CsvFragmentScanOptions& csv_options, const FragmentScanRequest& scan_request,
+      const CsvInspectedFragment& inspected_fragment) {
+    // We use the convert options given by the user but override which columns we are
+    // looking for.
+    auto convert_options = csv_options.convert_options;
+    std::vector<std::string> columns;
+    std::unordered_map<std::string, std::shared_ptr<DataType>> column_types;
+    for (const auto& scan_column : scan_request.columns) {
+      if (scan_column.path.indices().size() != 1) {
+        return Status::Invalid("CSV reader does not support nested references");
+      }
+      const std::string& column_name =
+          inspected_fragment.column_names[scan_column.path.indices()[0]];
+      columns.push_back(column_name);
+      column_types[column_name] = scan_column.requested_type->GetSharedPtr();
+    }
+    convert_options.include_columns = std::move(columns);
+    convert_options.column_types = std::move(column_types);
+    return std::move(convert_options);
+  }
+
+  static Future<std::shared_ptr<FragmentScanner>> Make(
+      const CsvFragmentScanOptions& csv_options, const FragmentScanRequest& scan_request,
+      const CsvInspectedFragment& inspected_fragment, Executor* cpu_executor) {
+    auto read_options = csv_options.read_options;
+
+    int num_batches = static_cast<int>(bit_util::CeilDiv(
+        inspected_fragment.num_bytes, static_cast<int64_t>(read_options.block_size)));
+    // Could be better, but a reasonable starting point. CSV presumably takes up more
+    // space than an in-memory format so this should be conservative.
+    int64_t best_guess_bytes_per_batch = read_options.block_size;
+    ARROW_ASSIGN_OR_RAISE(
+        csv::ConvertOptions convert_options,
+        GetConvertOptions(csv_options, scan_request, inspected_fragment));
+
+    return csv::StreamingReader::MakeAsync(
+               io::default_io_context(), inspected_fragment.input_stream, cpu_executor,
+               read_options, csv_options.parse_options, convert_options)
+        .Then([num_batches, best_guess_bytes_per_batch](
+                  const std::shared_ptr<csv::StreamingReader>& reader)
+                  -> std::shared_ptr<FragmentScanner> {
+          return std::make_shared<CsvFileScanner>(reader, num_batches,
+                                                  best_guess_bytes_per_batch);
+        });
+  }
+
+ private:
+  std::shared_ptr<csv::StreamingReader> reader_;
+  int num_batches_;
+  int64_t best_guess_bytes_per_batch_;
+
+  int scanned_so_far_ = 0;
+};
+
 using RecordBatchGenerator = std::function<Future<std::shared_ptr<RecordBatch>>()>;
 
-Result<std::unordered_set<std::string>> GetColumnNames(
+Result<std::vector<std::string>> GetOrderedColumnNames(
     const csv::ReadOptions& read_options, const csv::ParseOptions& parse_options,
     std::string_view first_block, MemoryPool* pool) {
   // Skip BOM when reading column names (ARROW-14644, ARROW-17382)
@@ -64,13 +154,7 @@ Result<std::unordered_set<std::string>> GetColumnNames(
   size = size - static_cast<size_t>(data_no_bom - data);
   first_block = std::string_view(reinterpret_cast<const char*>(data_no_bom), size);
 
   if (!read_options.column_names.empty()) {
-    std::unordered_set<std::string> column_names;
-    for (const auto& s : read_options.column_names) {
-      if (!column_names.emplace(s).second) {
-        return Status::Invalid("CSV file contained multiple columns named ", s);
-      }
-    }
-    return column_names;
+    return read_options.column_names;
   }
 
   uint32_t parsed_size = 0;
@@ -90,14 +174,14 @@ Result<std::unordered_set<std::string>> GetColumnNames(
     return Status::Invalid("No columns in CSV file");
   }
 
-  std::unordered_set<std::string> column_names;
+  std::vector<std::string> column_names;
 
   if (read_options.autogenerate_column_names) {
     column_names.reserve(parser.num_cols());
     for (int32_t i = 0; i < parser.num_cols(); ++i) {
       std::stringstream ss;
       ss << "f" << i;
-      column_names.emplace(ss.str());
+      column_names.emplace_back(ss.str());
     }
     return column_names;
   }
@@ -105,15 +189,28 @@ Result<std::unordered_set<std::string>> GetColumnNames(
   RETURN_NOT_OK(
       parser.VisitLastRow([&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
         std::string_view view{reinterpret_cast<const char*>(data), size};
-        if (column_names.emplace(std::string(view)).second) {
-          return Status::OK();
-        }
-        return Status::Invalid("CSV file contained multiple columns named ", view);
+        column_names.emplace_back(view);
+        return Status::OK();
       }));
 
   return column_names;
 }
 
+Result<std::unordered_set<std::string>> GetColumnNames(
+    const csv::ReadOptions& read_options, const csv::ParseOptions& parse_options,
+    std::string_view first_block, MemoryPool* pool) {
+  ARROW_ASSIGN_OR_RAISE(
+      std::vector<std::string> ordered_names,
+      GetOrderedColumnNames(read_options, parse_options, first_block, pool));
+  std::unordered_set<std::string> unordered_names;
+  for (const auto& column : ordered_names) {
+    if (!unordered_names.emplace(column).second) {
+      return Status::Invalid("CSV file contained multiple columns named ", column);
+    }
+  }
+  return unordered_names;
+}
+
 static inline Result<csv::ConvertOptions> GetConvertOptions(
     const CsvFileFormat& format, const ScanOptions* scan_options,
     const std::string_view first_block) {
@@ -251,6 +348,8 @@ static RecordBatchGenerator GeneratorFromReader(
   return MakeFromFuture(std::move(gen_fut));
 }
 
+CsvFileFormat::CsvFileFormat()
+    : FileFormat(std::make_shared<CsvFragmentScanOptions>()) {}
+
 bool CsvFileFormat::Equals(const FileFormat& format) const {
   if (type_name() != format.type_name()) return false;
 
@@ -312,6 +411,52 @@ Future<std::optional<int64_t>> CsvFileFormat::CountRows(
       .Then([](int64_t count) { return std::make_optional<int64_t>(count); });
 }
 
+Future<std::shared_ptr<FragmentScanner>> CsvFileFormat::BeginScan(
+    const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+    const FragmentScanOptions* format_options, compute::ExecContext* exec_context) const {
+  auto csv_options = static_cast<const CsvFragmentScanOptions*>(format_options);
+  auto csv_fragment = static_cast<const CsvInspectedFragment&>(inspected_fragment);
+  return CsvFileScanner::Make(*csv_options, request, csv_fragment,
+                              exec_context->executor());
+}
+
+Result<std::shared_ptr<InspectedFragment>> DoInspectFragment(
+    const FileSource& source, const CsvFragmentScanOptions& csv_options,
+    compute::ExecContext* exec_context) {
+  ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed());
+  if (csv_options.stream_transform_func) {
+    ARROW_ASSIGN_OR_RAISE(input, csv_options.stream_transform_func(input));
+  }
+  ARROW_ASSIGN_OR_RAISE(
+      input, io::BufferedInputStream::Create(csv_options.read_options.block_size,
+                                             default_memory_pool(), std::move(input)));
+
+  ARROW_ASSIGN_OR_RAISE(std::string_view first_block,
+                        input->Peek(csv_options.read_options.block_size));
+
+  ARROW_ASSIGN_OR_RAISE(
+      std::vector<std::string> column_names,
+      GetOrderedColumnNames(csv_options.read_options, csv_options.parse_options,
+                            first_block, exec_context->memory_pool()));
+  return std::make_shared<CsvInspectedFragment>(std::move(column_names),
+                                                std::move(input), source.Size());
+}
+
+Future<std::shared_ptr<InspectedFragment>> CsvFileFormat::InspectFragment(
+    const FileSource& source, const FragmentScanOptions* format_options,
+    compute::ExecContext* exec_context) const {
+  auto csv_options = static_cast<const CsvFragmentScanOptions*>(format_options);
+  Executor* io_executor;
+  if (source.filesystem()) {
+    io_executor = source.filesystem()->io_context().executor();
+  } else {
+    io_executor = exec_context->executor();
+  }
+  return DeferNotOk(io_executor->Submit([source, csv_options, exec_context]() {
+    return DoInspectFragment(source, *csv_options, exec_context);
+  }));
+}
+
 //
 // CsvFileWriter, CsvFileWriteOptions
 //
diff --git a/cpp/src/arrow/dataset/file_csv.h b/cpp/src/arrow/dataset/file_csv.h
index a3d214ef494..42e3fd72469 100644
--- a/cpp/src/arrow/dataset/file_csv.h
+++ b/cpp/src/arrow/dataset/file_csv.h
@@ -41,9 +41,12 @@ constexpr char kCsvTypeName[] = "csv";
 /// \brief A FileFormat implementation that reads from and writes to Csv files
 class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
  public:
+  // TODO(ARROW-18328) Remove this, moved to CsvFragmentScanOptions
   /// Options affecting the parsing of CSV files
   csv::ParseOptions parse_options = csv::ParseOptions::Defaults();
 
+  CsvFileFormat();
+
   std::string type_name() const override { return kCsvTypeName; }
 
   bool Equals(const FileFormat& other) const override;
@@ -53,10 +56,19 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
   /// \brief Return the schema of the file if possible.
   Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const override;
 
+  Future<std::shared_ptr<FragmentScanner>> BeginScan(
+      const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+      const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) const override;
+
   Result<RecordBatchGenerator> ScanBatchesAsync(
       const std::shared_ptr<ScanOptions>& scan_options,
       const std::shared_ptr<FileFragment>& file) const override;
 
+  Future<std::shared_ptr<InspectedFragment>> InspectFragment(
+      const FileSource& source, const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) const override;
+
   Future<std::optional<int64_t>> CountRows(
       const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
       const std::shared_ptr<ScanOptions>& options) override;
@@ -84,6 +96,9 @@ struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
   /// Note that use_threads is always ignored.
   csv::ReadOptions read_options = csv::ReadOptions::Defaults();
 
+  /// CSV parse options
+  csv::ParseOptions parse_options = csv::ParseOptions::Defaults();
+
   /// Optional stream wrapping function
   ///
   /// If defined, all open dataset file fragments will be passed
diff --git a/cpp/src/arrow/dataset/file_csv_test.cc b/cpp/src/arrow/dataset/file_csv_test.cc
index 2ee62935f02..dd01a5aa3e7 100644
--- a/cpp/src/arrow/dataset/file_csv_test.cc
+++ b/cpp/src/arrow/dataset/file_csv_test.cc
@@ -25,6 +25,7 @@
 #include "arrow/dataset/dataset_internal.h"
 #include "arrow/dataset/file_base.h"
 #include "arrow/dataset/partition.h"
+#include "arrow/dataset/plan.h"
 #include "arrow/dataset/test_util.h"
 #include "arrow/filesystem/mockfs.h"
 #include "arrow/io/compressed.h"
@@ -58,10 +59,23 @@ class CsvFormatHelper {
   }
 };
 
+struct CsvFileFormatParams {
+  Compression::type compression_type;
+  bool use_new_scan_v2;
+};
+
 class TestCsvFileFormat : public FileFormatFixtureMixin<CsvFormatHelper>,
-                          public ::testing::WithParamInterface<Compression::type> {
+                          public ::testing::WithParamInterface<CsvFileFormatParams> {
  public:
-  Compression::type GetCompression() { return GetParam(); }
+  bool UseScanV2() { return GetParam().use_new_scan_v2; }
+  Compression::type GetCompression() { return GetParam().compression_type; }
+
+  void SetUp() {
+    internal::Initialize();
+    auto fragment_scan_options = std::make_shared<CsvFragmentScanOptions>();
+    fragment_scan_options->parse_options.ignore_empty_lines = false;
+    opts_->fragment_scan_options = fragment_scan_options;
+  }
 
   std::unique_ptr<FileSource> GetFileSource(std::string csv) {
     if (GetCompression() == Compression::UNCOMPRESSED) {
@@ -96,9 +110,44 @@ class TestCsvFileFormat : public FileFormatFixtureMixin<CsvFormatHelper>,
     return std::make_unique<FileSource>(info, fs, GetCompression());
   }
 
-  RecordBatchIterator Batches(Fragment* fragment) {
-    EXPECT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(opts_));
-    return MakeGeneratorIterator(batch_gen);
+  CsvFragmentScanOptions MakeDefaultFormatOptions() {
+    CsvFragmentScanOptions scan_opts;
+    scan_opts.parse_options.ignore_empty_lines = false;
+    return scan_opts;
+  }
+
+  ScanV2Options MigrateLegacyOptions(std::shared_ptr<Fragment> fragment) {
+    std::shared_ptr<Dataset> dataset = std::make_shared<FragmentDataset>(
+        opts_->dataset_schema, FragmentVector{std::move(fragment)});
+    ScanV2Options updated(std::move(dataset));
+    updated.format_options = opts_->fragment_scan_options.get();
+    updated.filter = opts_->filter;
+    for (const auto& field : opts_->projected_schema->fields()) {
+      auto field_name = field->name();
+      EXPECT_OK_AND_ASSIGN(FieldPath field_path,
+                           FieldRef(field_name).FindOne(*opts_->dataset_schema));
+      updated.columns.push_back(field_path);
+    }
+    return updated;
+  }
+
+  RecordBatchIterator Batches(const std::shared_ptr<Fragment>& fragment) {
+    if (UseScanV2()) {
+      ScanV2Options v2_options = MigrateLegacyOptions(fragment);
+      EXPECT_TRUE(ScanV2Options::AddFieldsNeededForFilter(&v2_options).ok());
+      EXPECT_OK_AND_ASSIGN(std::unique_ptr<RecordBatchReader> reader,
+                           compute::DeclarationToReader(
+                               compute::Declaration("scan2", std::move(v2_options)),
+                               /*use_threads=*/false));
+      struct ReaderIterator {
+        Result<std::shared_ptr<RecordBatch>> Next() { return reader->Next(); }
+        std::unique_ptr<RecordBatchReader> reader;
+      };
+      return RecordBatchIterator(ReaderIterator{std::move(reader)});
+    } else {
+      EXPECT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(opts_));
+      return MakeGeneratorIterator(batch_gen);
+    }
   }
 };
 
@@ -111,7 +160,7 @@ TEST_P(TestCsvFileFormat, BOMQuoteInHeader) {
 
   int64_t row_count = 0;
 
-  for (auto maybe_batch : Batches(fragment.get())) {
+  for (auto maybe_batch : Batches(fragment)) {
     ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
     AssertSchemaEqual(batch->schema(), schema(fields));
     row_count += batch->num_rows();
@@ -133,7 +182,7 @@ N/A
 
   int64_t row_count = 0;
 
-  for (auto maybe_batch : Batches(fragment.get())) {
+  for (auto maybe_batch : Batches(fragment)) {
     ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
     row_count += batch->num_rows();
   }
@@ -149,13 +198,13 @@ N/A
 bar)");
   SetSchema({field("str", utf8())});
   auto fragment = MakeFragment(*source);
-  auto fragment_scan_options = std::make_shared<CsvFragmentScanOptions>();
+  auto fragment_scan_options =
+      static_cast<CsvFragmentScanOptions*>(opts_->fragment_scan_options.get());
   fragment_scan_options->convert_options.null_values = {"MYNULL"};
   fragment_scan_options->convert_options.strings_can_be_null = true;
-  opts_->fragment_scan_options = fragment_scan_options;
   int64_t null_count = 0;
-  for (auto maybe_batch : Batches(fragment.get())) {
+  for (auto maybe_batch : Batches(fragment)) {
     ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
     null_count += batch->GetColumnByName("str")->null_count();
   }
@@ -175,12 +224,13 @@ bar)");
     auto defaults = std::make_shared<CsvFragmentScanOptions>();
     defaults->read_options.skip_rows = 1;
     format_->default_fragment_scan_options = defaults;
+    opts_->fragment_scan_options = nullptr;
    auto fragment = MakeFragment(*source);
     ASSERT_OK_AND_ASSIGN(auto physical_schema, fragment->ReadPhysicalSchema());
     AssertSchemaEqual(opts_->dataset_schema, physical_schema);
     int64_t rows = 0;
-    for (auto maybe_batch : Batches(fragment.get())) {
+    for (auto maybe_batch : Batches(fragment)) {
       ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
       rows += batch->GetColumnByName("str")->length();
     }
@@ -189,12 +239,13 @@ bar)");
   {
     SetSchema({field("header_skipped", utf8())});
     // These options completely override the default ones
-    auto fragment_scan_options = std::make_shared<CsvFragmentScanOptions>();
+    opts_->fragment_scan_options = std::make_shared<CsvFragmentScanOptions>();
+    auto fragment_scan_options =
+        static_cast<CsvFragmentScanOptions*>(opts_->fragment_scan_options.get());
     fragment_scan_options->read_options.block_size = 1 << 22;
-    opts_->fragment_scan_options = fragment_scan_options;
     int64_t rows = 0;
     auto fragment = MakeFragment(*source);
-    for (auto maybe_batch : Batches(fragment.get())) {
+    for (auto maybe_batch : Batches(fragment)) {
       ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
       rows += batch->GetColumnByName("header_skipped")->length();
     }
@@ -208,7 +259,7 @@ bar)");
     opts_->fragment_scan_options = nullptr;
     int64_t rows = 0;
     auto fragment = MakeFragment(*source);
-    for (auto maybe_batch : Batches(fragment.get())) {
+    for (auto maybe_batch : Batches(fragment)) {
       ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
       rows += batch->GetColumnByName("custom_header")->length();
     }
@@ -222,11 +273,12 @@ TEST_P(TestCsvFileFormat, CustomReadOptionsColumnNames) {
   auto defaults = std::make_shared<CsvFragmentScanOptions>();
   defaults->read_options.column_names = {"ints_1", "ints_2"};
   format_->default_fragment_scan_options = defaults;
+  opts_->fragment_scan_options = nullptr;
   auto fragment = MakeFragment(*source);
   ASSERT_OK_AND_ASSIGN(auto physical_schema, fragment->ReadPhysicalSchema());
   AssertSchemaEqual(opts_->dataset_schema, physical_schema);
   int64_t rows = 0;
-  for (auto maybe_batch : Batches(fragment.get())) {
+  for (auto maybe_batch : Batches(fragment)) {
     ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
     rows += batch->num_rows();
   }
@@ -235,9 +287,25 @@ TEST_P(TestCsvFileFormat, CustomReadOptionsColumnNames) {
   defaults->read_options.column_names = {"same", "same"};
   format_->default_fragment_scan_options = defaults;
   fragment = MakeFragment(*source);
-  EXPECT_RAISES_WITH_MESSAGE_THAT(
-      Invalid, ::testing::HasSubstr("CSV file contained multiple columns named same"),
-      Batches(fragment.get()).Next());
+  SetSchema({field("same", int64())});
+  if (UseScanV2()) {
+    // V2 scan method's basic evolution strategy builds ds_to_frag_map and just finds
+    // the first instance of a matching column and doesn't check further to see if
+    // there are duplicates. So in this case it would grab the first column.
+    //
+    // Not clear if this is a good thing or not.
+    rows = 0;
+    for (auto maybe_batch : Batches(fragment)) {
+      ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+      rows += batch->num_rows();
+    }
+    ASSERT_EQ(rows, 2);
+  } else {
+    // Legacy scan method does not support CSV columns with duplicate names
+    EXPECT_RAISES_WITH_MESSAGE_THAT(
+        Invalid, ::testing::HasSubstr("CSV file contained multiple columns named same"),
+        Batches(fragment).Next());
+  }
 }
 
 TEST_P(TestCsvFileFormat, ScanRecordBatchReaderWithVirtualColumn) {
@@ -255,12 +323,17 @@ N/A
 
   int64_t row_count = 0;
 
-  for (auto maybe_batch : Batches(fragment.get())) {
+  for (auto maybe_batch : Batches(fragment)) {
     ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
-    AssertSchemaEqual(*batch->schema(), *physical_schema);
+    if (UseScanV2()) {
+      // In the new scan, evolution happens and inserts a null column in place of the
+      // virtual column
+      AssertSchemaEqual(*batch->schema(), *opts_->dataset_schema);
+    } else {
+      AssertSchemaEqual(*batch->schema(), *physical_schema);
+    }
     row_count += batch->num_rows();
   }
-
   ASSERT_EQ(row_count, 4);
 }
 
@@ -379,23 +452,39 @@ TEST_P(TestCsvFileFormat, CountRows) { TestCountRows(); }
 
 TEST_P(TestCsvFileFormat, FragmentEquals) { TestFragmentEquals(); }
 
 INSTANTIATE_TEST_SUITE_P(TestUncompressedCsv, TestCsvFileFormat,
-                         ::testing::Values(Compression::UNCOMPRESSED));
+                         ::testing::Values(CsvFileFormatParams{Compression::UNCOMPRESSED,
+                                                               false}));
+INSTANTIATE_TEST_SUITE_P(TestUncompressedCsvV2, TestCsvFileFormat,
+                         ::testing::Values(CsvFileFormatParams{Compression::UNCOMPRESSED,
+                                                               true}));
 #ifdef ARROW_WITH_BZ2
 INSTANTIATE_TEST_SUITE_P(TestBZ2Csv, TestCsvFileFormat,
-                         ::testing::Values(Compression::BZ2));
+                         ::testing::Values(CsvFileFormatParams{Compression::BZ2, false}));
+INSTANTIATE_TEST_SUITE_P(TestBZ2CsvV2, TestCsvFileFormat,
+                         ::testing::Values(CsvFileFormatParams{Compression::BZ2, true}));
 #endif
 #ifdef ARROW_WITH_LZ4
 INSTANTIATE_TEST_SUITE_P(TestLZ4Csv, TestCsvFileFormat,
-                         ::testing::Values(Compression::LZ4_FRAME));
+                         ::testing::Values(CsvFileFormatParams{Compression::LZ4_FRAME,
+                                                               false}));
+INSTANTIATE_TEST_SUITE_P(TestLZ4CsvV2, TestCsvFileFormat,
+                         ::testing::Values(CsvFileFormatParams{Compression::LZ4_FRAME,
+                                                               true}));
 #endif
 // Snappy does not support streaming compression
 #ifdef ARROW_WITH_ZLIB
-INSTANTIATE_TEST_SUITE_P(TestGZipCsv, TestCsvFileFormat,
-                         ::testing::Values(Compression::GZIP));
+INSTANTIATE_TEST_SUITE_P(TestGzipCsv, TestCsvFileFormat,
+                         ::testing::Values(CsvFileFormatParams{Compression::GZIP,
+                                                               false}));
+INSTANTIATE_TEST_SUITE_P(TestGzipCsvV2, TestCsvFileFormat,
+                         ::testing::Values(CsvFileFormatParams{Compression::GZIP, true}));
 #endif
 #ifdef ARROW_WITH_ZSTD
 INSTANTIATE_TEST_SUITE_P(TestZSTDCsv, TestCsvFileFormat,
-                         ::testing::Values(Compression::ZSTD));
+                         ::testing::Values(CsvFileFormatParams{Compression::ZSTD,
+                                                               false}));
+INSTANTIATE_TEST_SUITE_P(TestZSTDCsvV2, TestCsvFileFormat,
+                         ::testing::Values(CsvFileFormatParams{Compression::ZSTD, true}));
 #endif
 
 class TestCsvFileFormatScan : public FileFormatScanMixin<CsvFormatHelper> {};
@@ -422,5 +511,35 @@ INSTANTIATE_TEST_SUITE_P(TestScan, TestCsvFileFormatScan,
                         ::testing::ValuesIn(TestFormatParams::Values()),
                         TestFormatParams::ToTestNameString);
 
+class TestCsvFileFormatScanNode : public FileFormatScanNodeMixin<CsvFormatHelper> {
+  void SetUp() override {
+    internal::Initialize();
+    scan_options_.parse_options.ignore_empty_lines = false;
+  }
+
+  const FragmentScanOptions* GetFormatOptions() override { return &scan_options_; }
+
+ protected:
+  CsvFragmentScanOptions scan_options_;
+};
+
+TEST_P(TestCsvFileFormatScanNode, Scan) { TestScan(); }
+TEST_P(TestCsvFileFormatScanNode, ScanProjected) { TestScanProjected(); }
+TEST_P(TestCsvFileFormatScanNode, ScanMissingFilterField) {
+  TestScanMissingFilterField();
+}
+// NOTE(ARROW-14658): TestScanProjectedNested is ignored since CSV
+// doesn't have any nested types for us to work with
+TEST_P(TestCsvFileFormatScanNode, ScanProjectedMissingColumns) {
+  TestScanProjectedMissingCols();
+}
+TEST_P(TestCsvFileFormatScanNode, ScanWithDuplicateColumn) {
+  TestScanWithDuplicateColumn();
+}
+// NOTE: TestScanWithPushdownNulls is ignored since CSV doesn't handle pushdown filtering
+INSTANTIATE_TEST_SUITE_P(TestScanNode, TestCsvFileFormatScanNode,
+                         ::testing::ValuesIn(TestFormatParams::Values()),
+                         TestFormatParams::ToTestNameString);
+
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/file_ipc.cc b/cpp/src/arrow/dataset/file_ipc.cc
index 2650db499ce..a8a19c714c1 100644
--- a/cpp/src/arrow/dataset/file_ipc.cc
+++ b/cpp/src/arrow/dataset/file_ipc.cc
@@ -124,6 +124,8 @@ static inline Result<ipc::IpcReadOptions> GetReadOptions(
   return options;
 }
 
+IpcFileFormat::IpcFileFormat() : FileFormat(std::make_shared<IpcFragmentScanOptions>()) {}
+
 Result<bool> IpcFileFormat::IsSupported(const FileSource& source) const {
   RETURN_NOT_OK(source.Open().status());
   return OpenReader(source).ok();
 }
diff --git a/cpp/src/arrow/dataset/file_ipc.h b/cpp/src/arrow/dataset/file_ipc.h
index 8b97046271b..0f7da82a0af 100644
--- a/cpp/src/arrow/dataset/file_ipc.h
+++ b/cpp/src/arrow/dataset/file_ipc.h
@@ -43,6 +43,8 @@ class ARROW_DS_EXPORT IpcFileFormat : public FileFormat {
 public:
   std::string type_name() const override { return kIpcTypeName; }
 
+  IpcFileFormat();
+
   bool Equals(const FileFormat& other) const override {
     return type_name() == other.type_name();
   }
diff --git a/cpp/src/arrow/dataset/file_orc.cc b/cpp/src/arrow/dataset/file_orc.cc
index cf04e5e7484..1393df57f9d 100644
--- a/cpp/src/arrow/dataset/file_orc.cc
+++ b/cpp/src/arrow/dataset/file_orc.cc
@@ -142,6 +142,8 @@ class OrcScanTaskIterator {
 
 }  // namespace
 
+OrcFileFormat::OrcFileFormat() : FileFormat(/*default_fragment_scan_options=*/nullptr) {}
+
 Result<bool> OrcFileFormat::IsSupported(const FileSource& source) const {
   RETURN_NOT_OK(source.Open().status());
   return OpenORCReader(source).ok();
 }
diff --git a/cpp/src/arrow/dataset/file_orc.h b/cpp/src/arrow/dataset/file_orc.h
index cbfb83670cb..5bfefd1e02b 100644
--- a/cpp/src/arrow/dataset/file_orc.h
+++ b/cpp/src/arrow/dataset/file_orc.h
@@ -40,6 +40,8 @@ constexpr char kOrcTypeName[] = "orc";
 /// \brief A FileFormat implementation that reads from and writes to ORC files
 class ARROW_DS_EXPORT OrcFileFormat : public FileFormat {
  public:
+  OrcFileFormat();
+
   std::string type_name() const override { return kOrcTypeName; }
 
   bool Equals(const FileFormat& other) const override {
diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc
index 30bb4840d31..7080a766abe 100644
--- a/cpp/src/arrow/dataset/file_parquet.cc
+++ b/cpp/src/arrow/dataset/file_parquet.cc
@@ -306,6 +306,9 @@ Result<bool> IsSupportedParquetFile(const ParquetFileFormat& format,
 
 }  // namespace
 
+ParquetFileFormat::ParquetFileFormat()
+    : FileFormat(std::make_shared<ParquetFragmentScanOptions>()) {}
+
 bool ParquetFileFormat::Equals(const FileFormat& other) const {
   if (other.type_name() != type_name()) return false;
 
@@ -318,10 +321,11 @@ bool ParquetFileFormat::Equals(const FileFormat& other) const {
          other_reader_options.coerce_int96_timestamp_unit);
 }
 
-ParquetFileFormat::ParquetFileFormat(const parquet::ReaderProperties& reader_properties) {
-  auto parquet_scan_options = std::make_shared<ParquetFragmentScanOptions>();
-  *parquet_scan_options->reader_properties = reader_properties;
-  default_fragment_scan_options = std::move(parquet_scan_options);
+ParquetFileFormat::ParquetFileFormat(const parquet::ReaderProperties& reader_properties)
+    : FileFormat(std::make_shared<ParquetFragmentScanOptions>()) {
+  auto* default_scan_opts =
+      static_cast<ParquetFragmentScanOptions*>(default_fragment_scan_options.get());
+  *default_scan_opts->reader_properties = reader_properties;
 }
 
 Result<bool> ParquetFileFormat::IsSupported(const FileSource& source) const {
diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h
index 05c02940d35..1087fb9f9de 100644
--- a/cpp/src/arrow/dataset/file_parquet.h
+++ b/cpp/src/arrow/dataset/file_parquet.h
@@ -66,7 +66,7 @@ constexpr char kParquetTypeName[] = "parquet";
 /// \brief A FileFormat implementation that reads from Parquet files
 class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat {
  public:
-  ParquetFileFormat() = default;
+  ParquetFileFormat();
 
   /// Convenience constructor which copies properties from a parquet::ReaderProperties.
   /// memory_pool will be ignored.
diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc
index 6d866c196b3..09a1fe1171b 100644
--- a/cpp/src/arrow/dataset/file_test.cc
+++ b/cpp/src/arrow/dataset/file_test.cc
@@ -94,6 +94,8 @@ constexpr int kNumBatches = 4;
 constexpr int kRowsPerBatch = 1024;
 class MockFileFormat : public FileFormat {
  public:
+  MockFileFormat() : FileFormat(/*default_fragment_scan_options=*/nullptr) {}
+
   Result<RecordBatchGenerator> ScanBatchesAsync(
       const std::shared_ptr<ScanOptions>& options,
       const std::shared_ptr<FileFragment>& file) const override {
diff --git a/cpp/src/arrow/dataset/scan_node.cc b/cpp/src/arrow/dataset/scan_node.cc
index 23a292f5d86..4e644c1de76 100644
--- a/cpp/src/arrow/dataset/scan_node.cc
+++ b/cpp/src/arrow/dataset/scan_node.cc
@@ -52,29 +52,27 @@ Result<std::shared_ptr<Schema>> OutputSchemaFromOptions(const ScanV2Options& opt
 // In the future we should support async scanning of fragments.  The
 // Dataset class doesn't support this yet but we pretend it does here to
 // ease future adoption of the feature.
-AsyncGenerator<std::shared_ptr<Fragment>> GetFragments(Dataset* dataset,
-                                                       cp::Expression predicate) {
+Future<AsyncGenerator<std::shared_ptr<Fragment>>> GetFragments(Dataset* dataset,
+                                                               cp::Expression predicate) {
   // In the future the dataset should be responsible for figuring out
   // the I/O context.  This will allow different I/O contexts to be used
   // when scanning different datasets.  For example, if we are scanning a
   // union of a remote dataset and a local dataset.
   const auto& io_context = io::default_io_context();
   auto io_executor = io_context.executor();
-  Future<std::shared_ptr<FragmentIterator>> fragments_it_fut =
-      DeferNotOk(io_executor->Submit(
-          [dataset, predicate]() -> Result<std::shared_ptr<FragmentIterator>> {
-            ARROW_ASSIGN_OR_RAISE(FragmentIterator fragments_iter,
-                                  dataset->GetFragments(predicate));
-            return std::make_shared<FragmentIterator>(std::move(fragments_iter));
-          }));
-  Future<AsyncGenerator<std::shared_ptr<Fragment>>> fragments_gen_fut =
-      fragments_it_fut.Then([](const std::shared_ptr<FragmentIterator>& fragments_it)
-                                -> Result<AsyncGenerator<std::shared_ptr<Fragment>>> {
+  return DeferNotOk(
+             io_executor->Submit(
+                 [dataset, predicate]() -> Result<std::shared_ptr<FragmentIterator>> {
+                   ARROW_ASSIGN_OR_RAISE(FragmentIterator fragments_iter,
+                                         dataset->GetFragments(predicate));
+                   return std::make_shared<FragmentIterator>(std::move(fragments_iter));
+                 }))
+      .Then([](const std::shared_ptr<FragmentIterator>& fragments_it)
+                -> Result<AsyncGenerator<std::shared_ptr<Fragment>>> {
         ARROW_ASSIGN_OR_RAISE(std::vector<std::shared_ptr<Fragment>> fragments,
                               fragments_it->ToVector());
         return MakeVectorGenerator(std::move(fragments));
       });
-  return MakeFromFuture(std::move(fragments_gen_fut));
 }
 
 /// \brief A node that scans a dataset
@@ -83,14 +81,14 @@ AsyncGenerator<std::shared_ptr<Fragment>> GetFragments(Dataset* dataset,
 ///
 /// The first io-task (listing) fetches the fragments from the dataset. This may be a
 /// simple iteration of paths or, if the dataset is described with wildcards, this may
-/// involve I/O for listing and walking directory paths. There is one listing io-task per
-/// dataset.
+/// involve I/O for listing and walking directory paths. There is one listing io-task
+/// per dataset.
 ///
-/// Ths next step is to fetch the metadata for the fragment. For some formats (e.g. CSV)
-/// this may be quite simple (get the size of the file). For other formats (e.g. parquet)
-/// this is more involved and requires reading data. There is one metadata io-task per
-/// fragment. The metadata io-task creates an AsyncGenerator<ExecBatch> from the
-/// fragment.
+/// The next step is to fetch the metadata for the fragment. For some formats (e.g.
+/// CSV) this may be quite simple (get the size of the file). For other formats (e.g.
+/// parquet) this is more involved and requires reading data. There is one metadata
+/// io-task per fragment. The metadata io-task creates an AsyncGenerator<ExecBatch>
+/// from the fragment.
 ///
 /// Once the metadata io-task is done we can issue read io-tasks. Each read io-task
 /// requests a single batch of data from the disk by pulling the next Future from the
@@ -100,9 +98,9 @@ AsyncGenerator<std::shared_ptr<Fragment>> GetFragments(Dataset* dataset,
 /// through the pipeline.
 ///
 /// Most of these tasks are io-tasks. They take very few CPU resources and they run on
-/// the I/O thread pool. These io-tasks are invisible to the exec plan and so we need to
-/// do some custom scheduling. We limit how many fragments we read from at any one time.
-/// This is referred to as "fragment readahead".
+/// the I/O thread pool. These io-tasks are invisible to the exec plan and so we need
+/// to do some custom scheduling. We limit how many fragments we read from at any one
+/// time. This is referred to as "fragment readahead".
 ///
 /// Within a fragment there is usually also some amount of "row readahead". This row
 /// readahead is handled by the fragment (and not the scanner) because the exact details
@@ -146,12 +144,18 @@ class ScanNode : public cp::ExecNode {
       // function registry as the one in ctx so we just require it to be unbound
       // FIXME - Do we care if it was bound to a different function registry?
       return Status::Invalid("Scan filter must be unbound");
-    } else if (!normalized.filter.IsBound()) {
+    } else {
       ARROW_ASSIGN_OR_RAISE(normalized.filter,
                             normalized.filter.Bind(*options.dataset->schema(), ctx));
+      ARROW_ASSIGN_OR_RAISE(normalized.filter,
+                            compute::RemoveNamedRefs(std::move(normalized.filter)));
     }
     // Else we must have some simple filter like literal(true) which might be bound
     // but we don't care
 
+    if (normalized.filter.type()->id() != Type::BOOL) {
+      return Status::Invalid("A scan filter must be a boolean expression");
+    }
+
     return std::move(normalized);
   }
 
@@ -190,9 +194,10 @@ class ScanNode : public cp::ExecNode {
         : node_(node), scan_(scan_state), batch_index_(batch_index) {
       int64_t cost = scan_state->fragment_scanner->EstimatedDataBytes(batch_index_);
       // It's possible, though probably a bad idea, for a single batch of a fragment
-      // to be larger than 2GiB. In that case, it doesn't matter much if we underestimate
-      // because the largest the throttle can be is 2GiB and thus we will be in "one batch
-      // at a time" mode anyways which is the best we can do in this case.
+      // to be larger than 2GiB. In that case, it doesn't matter much if we
+      // underestimate because the largest the throttle can be is 2GiB and thus we will
+      // be in "one batch at a time" mode anyways which is the best we can do in this
+      // case.
       cost_ = static_cast<int>(
          std::min(cost, static_cast<int64_t>(std::numeric_limits<int>::max())));
     }
@@ -231,8 +236,9 @@ class ScanNode : public cp::ExecNode {
         : node(node), fragment(std::move(fragment)) {}
 
     Result<Future<>> operator()() override {
-      return fragment->InspectFragment().Then(
-          [this](const std::shared_ptr<InspectedFragment>& inspected_fragment) {
+      return fragment
+          ->InspectFragment(node->options_.format_options, node->plan_->exec_context())
+          .Then([this](const std::shared_ptr<InspectedFragment>& inspected_fragment) {
             return BeginScan(inspected_fragment);
           });
     }
@@ -244,7 +250,9 @@ class ScanNode : public cp::ExecNode {
           node->options_.dataset->evolution_strategy()->GetStrategy(
               *node->options_.dataset, *fragment, *inspected_fragment);
       ARROW_RETURN_NOT_OK(InitFragmentScanRequest());
-      return fragment->BeginScan(scan_state->scan_request, *inspected_fragment)
+      return fragment
+          ->BeginScan(scan_state->scan_request, *inspected_fragment,
+                      node->options_.format_options, node->plan_->exec_context())
           .Then([this](const std::shared_ptr<FragmentScanner>& fragment_scanner) {
             return AddScanTasks(fragment_scanner);
           });
@@ -301,21 +309,11 @@ class ScanNode : public cp::ExecNode {
     std::unique_ptr<ScanState> scan_state = std::make_unique<ScanState>();
   };
 
-  Status StartProducing() override {
-    START_COMPUTE_SPAN(span_, std::string(kind_name()) + ":" + label(),
-                       {{"node.kind", kind_name()},
-                        {"node.label", label()},
-                        {"node.output_schema", output_schema()->ToString()},
-                        {"node.detail", ToString()}});
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
-    batches_throttle_ = util::ThrottledAsyncTaskScheduler::Make(
-        plan_->async_scheduler(), options_.target_bytes_readahead + 1);
-    AsyncGenerator<std::shared_ptr<Fragment>> frag_gen =
-        GetFragments(options_.dataset.get(), options_.filter);
+  void ScanFragments(const AsyncGenerator<std::shared_ptr<Fragment>>& frag_gen) {
     std::shared_ptr<util::ThrottledAsyncTaskScheduler> fragment_tasks =
         util::MakeThrottledAsyncTaskGroup(
-            plan_->async_scheduler(), options_.fragment_readahead + 1, /*queue=*/nullptr,
-            [this]() {
+            plan_->async_scheduler(), options_.fragment_readahead + 1,
+            /*queue=*/nullptr, [this]() {
              outputs_[0]->InputFinished(this, num_batches_.load());
              finished_.MarkFinished();
              return Status::OK();
            });
@@ -326,6 +324,23 @@ class ScanNode : public cp::ExecNode {
       fragment_tasks->AddTask(std::make_unique<ScanFragmentTask>(this, fragment));
       return Status::OK();
     });
+  }
+
+  Status StartProducing() override {
+    START_COMPUTE_SPAN(span_, std::string(kind_name()) + ":" + label(),
+                       {{"node.kind", kind_name()},
+                        {"node.label", label()},
+                        {"node.output_schema", output_schema()->ToString()},
+                        {"node.detail", ToString()}});
+    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
+    batches_throttle_ = util::ThrottledAsyncTaskScheduler::Make(
+        plan_->async_scheduler(), options_.target_bytes_readahead + 1);
+    plan_->async_scheduler()->AddSimpleTask([this] {
+      return GetFragments(options_.dataset.get(), options_.filter)
+          .Then([this](const AsyncGenerator<std::shared_ptr<Fragment>>& frag_gen) {
+            ScanFragments(frag_gen);
+          });
+    });
     return Status::OK();
   }
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index b4baab5a09f..3ee64a2b158 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -68,14 +68,28 @@ std::vector<FieldRef> ScanOptions::MaterializedFields() const {
   return fields;
 }
 
-std::vector<FieldPath> ScanV2Options::AllColumns(const Dataset& dataset) {
-  std::vector<FieldPath> selection(dataset.schema()->num_fields());
-  for (std::size_t i = 0; i < selection.size(); i++) {
-    selection[i] = {static_cast<int>(i)};
+std::vector<FieldPath> ScanV2Options::AllColumns(const Schema& dataset_schema) {
+  std::vector<FieldPath> selection(dataset_schema.num_fields());
+  for (int i = 0; i < dataset_schema.num_fields(); i++) {
+    selection[i] = {i};
   }
   return selection;
 }
 
+Status ScanV2Options::AddFieldsNeededForFilter(ScanV2Options* options) {
+  std::vector<FieldRef> fields_referenced = FieldsInExpression(options->filter);
+  for (const auto& field : fields_referenced) {
+    // Note: this will fail if the field reference is ambiguous or the field doesn't
+    // exist in the dataset schema
+    ARROW_ASSIGN_OR_RAISE(auto field_path, field.FindOne(*options->dataset->schema()));
+    if (std::find(options->columns.begin(), options->columns.end(), field_path) ==
+        options->columns.end()) {
+      options->columns.push_back(std::move(field_path));
+    }
+  }
+  return Status::OK();
+}
+
 namespace {
 class ScannerRecordBatchReader : public RecordBatchReader {
  public:
@@ -424,7 +438,8 @@ Result<AsyncGenerator<EnumeratedRecordBatch>> AsyncScanner::ScanBatchesUnorderedAsync(
          {"filter", compute::FilterNodeOptions{scan_options_->filter}},
          {"augmented_project",
           compute::ProjectNodeOptions{std::move(exprs), std::move(names)}},
-         {"sink", compute::SinkNodeOptions{&sink_gen, scan_options_->backpressure}},
+         {"sink", compute::SinkNodeOptions{&sink_gen, /*schema=*/nullptr,
+                                           scan_options_->backpressure}},
      })
          .AddToPlan(plan.get()));
 
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 7d6dd0e1a8c..6cbcaa0fc50 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -178,6 +178,13 @@ struct ARROW_DS_EXPORT ScanV2Options : public compute::ExecNodeOptions {
   ///
   /// A single guarantee-aware filtering operation should generally be applied to all
   /// resulting batches. The scan node is not responsible for this.
+  ///
+  /// Fields that are referenced by the filter should be included in the `columns` vector.
+  /// The scan node will not automatically fetch fields referenced by the filter
+  /// expression. \see AddFieldsNeededForFilter
+  ///
+  /// If the filter references fields that are not included in `columns` this may or may
+  /// not be an error, depending on the format.
   compute::Expression filter = compute::literal(true);
 
   /// \brief The columns to scan
@@ -243,10 +250,17 @@ struct ARROW_DS_EXPORT ScanV2Options : public compute::ExecNodeOptions {
   /// one fragment at a time.
   int32_t fragment_readahead = kDefaultFragmentReadahead;
   /// \brief Options specific to the file format
-  FragmentScanOptions* format_options;
+  const FragmentScanOptions* format_options = NULLPTR;
   /// \brief Utility method to get a selection representing all columns in a dataset
-  static std::vector<FieldPath> AllColumns(const Dataset& dataset);
+  static std::vector<FieldPath> AllColumns(const Schema& dataset_schema);
+
+  /// \brief Utility method to add fields needed for the current filter
+  ///
+  /// This method adds any fields that are needed by `filter` and are not already
+  /// included in the list of columns. Any fields it adds are appended to the end of
+  /// `columns`, in no particular order.
+  static Status AddFieldsNeededForFilter(ScanV2Options* options);
 };
 
 /// \brief Describes a projection
diff --git a/cpp/src/arrow/dataset/scanner_benchmark.cc b/cpp/src/arrow/dataset/scanner_benchmark.cc
index 0184fcce192..922f5c5787c 100644
--- a/cpp/src/arrow/dataset/scanner_benchmark.cc
+++ b/cpp/src/arrow/dataset/scanner_benchmark.cc
@@ -249,7 +249,7 @@ const std::function>(size_t, si
   // specify the filter
   compute::Expression b_is_true = equal(field_ref("b"), literal(true));
   options->filter = b_is_true;
-  options->columns = ScanV2Options::AllColumns(*dataset);
+  options->columns = ScanV2Options::AllColumns(*dataset->schema());
   return options;
 };
 
@@ -314,8 +314,8 @@ static void ScanBenchmark_Customize(benchmark::internal::Benchmark* b) {
   b->UseRealTime();
 }
 
-BENCHMARK(MinimalEndToEndBench)->Apply(ScanBenchmark_Customize)->Iterations(10);
-BENCHMARK(ScanOnlyBench)->Apply(ScanBenchmark_Customize)->Iterations(10);
+BENCHMARK(MinimalEndToEndBench)->Apply(ScanBenchmark_Customize);
+BENCHMARK(ScanOnlyBench)->Apply(ScanBenchmark_Customize);
 
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 1edebb1cbea..83da0a3daf8 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -243,14 +243,17 @@ struct MockFragment : public Fragment {
     return Status::Invalid("Not implemented because not needed by unit tests");
   };
 
-  Future<std::shared_ptr<InspectedFragment>> InspectFragment() override {
+  Future<std::shared_ptr<InspectedFragment>> InspectFragment(
+      const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) override {
     has_inspected_ = true;
     return inspected_future_;
   }
 
   Future<std::shared_ptr<FragmentScanner>> BeginScan(
-      const FragmentScanRequest& request,
-      const InspectedFragment& inspected_fragment) override {
+      const FragmentScanRequest& request, const InspectedFragment& inspected_fragment,
+      const FragmentScanOptions* format_options,
+      compute::ExecContext* exec_context) override {
     has_started_ = true;
     seen_request_ = request;
     return fragment_scanner_future_;
@@ -525,7 +528,7 @@ class TestScannerBase : public ::testing::TestWithParam {
   compute::Declaration MakeScanNode(std::shared_ptr<Dataset> dataset) {
     ScanV2Options options(dataset);
-    options.columns = ScanV2Options::AllColumns(*dataset);
+    options.columns = ScanV2Options::AllColumns(*dataset->schema());
     return compute::Declaration("scan2", options);
   }
 
@@ -641,7 +644,7 @@ TEST(TestNewScanner, Backpressure) {
   // No readahead
   options.dataset = test_dataset;
-  options.columns = ScanV2Options::AllColumns(*test_dataset);
+  options.columns = ScanV2Options::AllColumns(*test_dataset->schema());
   options.fragment_readahead = 0;
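Aside for reviewers: a minimal sketch of how the ScanV2Options pieces above compose. The `dataset` variable, the field name "filterable", and the literal values are hypothetical; `ScanV2Options`, `AllColumns`, `AddFieldsNeededForFilter`, and the readahead knobs are the APIs declared in the hunks above.

    // Assume `dataset` is a std::shared_ptr<arrow::dataset::Dataset> built elsewhere.
    arrow::dataset::ScanV2Options options(dataset);
    // Start from an explicit selection; FieldPath({0}) is the first top-level column.
    options.columns = {arrow::FieldPath({0})};
    // Filter on a field that is not in `columns`; the scan node will not fetch it
    // implicitly...
    options.filter = arrow::compute::greater(arrow::compute::field_ref("filterable"),
                                             arrow::compute::literal(50));
    // ...so append the filter's field references (added at the end, in no
    // particular order).
    ARROW_RETURN_NOT_OK(arrow::dataset::ScanV2Options::AddFieldsNeededForFilter(&options));
    // The "custom scheduling" knobs described in the scan node comment:
    options.fragment_readahead = 4;            // read from up to 4 fragments at a time
    options.target_bytes_readahead = 1 << 20;  // roughly 1MiB of batch readahead
    // Alternatively, scan every column:
    //   options.columns = arrow::dataset::ScanV2Options::AllColumns(*dataset->schema());
    arrow::compute::Declaration scan_decl("scan2", options);

The "scan2" factory name and the field name "filterable" mirror the tests below.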
options.target_bytes_readahead = 0; CheckScannerBackpressure(test_dataset, options, 1, 1, @@ -650,7 +653,7 @@ TEST(TestNewScanner, Backpressure) { // Some readahead test_dataset = MakeTestDataset(kNumFragments, kNumBatchesPerFragment); options = ScanV2Options(test_dataset); - options.columns = ScanV2Options::AllColumns(*test_dataset); + options.columns = ScanV2Options::AllColumns(*test_dataset->schema()); options.fragment_readahead = 4; // each batch should be 14Ki so 50Ki readahead should yield 3-at-a-time options.target_bytes_readahead = 50 * kRowsPerTestBatch; @@ -696,10 +699,10 @@ std::shared_ptr MakePartitionSkipDataset() { std::shared_ptr test_schema = ScannerTestSchema(); MockDatasetBuilder builder(test_schema); builder.AddFragment(test_schema, /*inspection=*/nullptr, - greater(field_ref("filterable"), literal(50))); + greater(field_ref({1}), literal(50))); builder.AddBatch(MakeTestBatch(0)); builder.AddFragment(test_schema, /*inspection=*/nullptr, - less_equal(field_ref("filterable"), literal(50))); + less_equal(field_ref({1}), literal(50))); builder.AddBatch(MakeTestBatch(1)); return builder.Finish(); } @@ -710,7 +713,7 @@ TEST(TestNewScanner, PartitionSkip) { test_dataset->DeliverBatchesInOrder(false); ScanV2Options options(test_dataset); - options.columns = ScanV2Options::AllColumns(*test_dataset); + options.columns = ScanV2Options::AllColumns(*test_dataset->schema()); options.filter = greater(field_ref("filterable"), literal(75)); ASSERT_OK_AND_ASSIGN(std::vector> batches, @@ -721,7 +724,7 @@ TEST(TestNewScanner, PartitionSkip) { test_dataset = MakePartitionSkipDataset(); test_dataset->DeliverBatchesInOrder(false); options = ScanV2Options(test_dataset); - options.columns = ScanV2Options::AllColumns(*test_dataset); + options.columns = ScanV2Options::AllColumns(*test_dataset->schema()); options.filter = less(field_ref("filterable"), literal(25)); ASSERT_OK_AND_ASSIGN(batches, compute::DeclarationToBatches({"scan2", options})); @@ -736,7 +739,7 @@ TEST(TestNewScanner, NoFragments) { std::shared_ptr test_dataset = builder.Finish(); ScanV2Options options(test_dataset); - options.columns = ScanV2Options::AllColumns(*test_dataset); + options.columns = ScanV2Options::AllColumns(*test_dataset->schema()); ASSERT_OK_AND_ASSIGN(std::vector> batches, compute::DeclarationToBatches({"scan2", options})); ASSERT_EQ(0, batches.size()); @@ -751,7 +754,7 @@ TEST(TestNewScanner, EmptyFragment) { test_dataset->DeliverBatchesInOrder(false); ScanV2Options options(test_dataset); - options.columns = ScanV2Options::AllColumns(*test_dataset); + options.columns = ScanV2Options::AllColumns(*test_dataset->schema()); ASSERT_OK_AND_ASSIGN(std::vector> batches, compute::DeclarationToBatches({"scan2", options})); ASSERT_EQ(0, batches.size()); @@ -769,7 +772,7 @@ TEST(TestNewScanner, EmptyBatch) { test_dataset->DeliverBatchesInOrder(false); ScanV2Options options(test_dataset); - options.columns = ScanV2Options::AllColumns(*test_dataset); + options.columns = ScanV2Options::AllColumns(*test_dataset->schema()); ASSERT_OK_AND_ASSIGN(std::vector> batches, compute::DeclarationToBatches({"scan2", options})); ASSERT_EQ(0, batches.size()); diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 17065bfd7d2..e6244fb0dfc 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -32,6 +32,7 @@ #include #include "arrow/array.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/dataset/dataset_internal.h" 
#include "arrow/dataset/discovery.h" @@ -915,11 +916,553 @@ class FileFormatScanMixin : public FileFormatFixtureMixin, using FileFormatFixtureMixin::opts_; }; +template +class FileFormatFixtureMixinV2 : public ::testing::Test { + public: + constexpr static int64_t kBatchSize = 1UL << 12; + constexpr static int64_t kBatchRepetitions = 1 << 5; + + FileFormatFixtureMixinV2() + : format_(FormatHelper::MakeFormat()), + // Set dataset to nullptr, we will fill it in later when (if) we scan + opts_(std::make_shared(/*dataset=*/nullptr)) {} + + int64_t expected_batches() const { return kBatchRepetitions; } + int64_t expected_rows() const { return kBatchSize * kBatchRepetitions; } + + std::shared_ptr MakeFragment(const FileSource& source) { + EXPECT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(source)); + return fragment; + } + + std::shared_ptr MakeFragment(const FileSource& source, + compute::Expression partition_expression) { + EXPECT_OK_AND_ASSIGN(auto fragment, + format_->MakeFragment(source, partition_expression)); + return fragment; + } + + std::shared_ptr MakeBufferSource(RecordBatchReader* reader) { + EXPECT_OK_AND_ASSIGN(auto buffer, FormatHelper::Write(reader)); + return std::make_shared(std::move(buffer)); + } + + virtual std::shared_ptr GetRandomData( + std::shared_ptr schema) { + return MakeGeneratedRecordBatch(schema, kBatchSize, kBatchRepetitions); + } + + Result> GetFileSink() { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr buffer, + AllocateResizableBuffer(0)); + return std::make_shared(buffer); + } + + void SetDatasetSchema(std::vector> fields) { + dataset_schema_ = schema(std::move(fields)); + SetScanProjectionAllColumns(); + } + + void CheckDatasetSchemaSet() { + DCHECK_NE(dataset_schema_, nullptr) + << "call SetDatasetSchema before calling this method"; + } + + void SetScanFilter(compute::Expression filter) { + CheckDatasetSchemaSet(); + opts_->filter = std::move(filter); + } + + void SetScanProjection(std::vector selection) { + opts_->columns = std::move(selection); + } + + void SetScanProjectionRefs(std::vector selection) { + opts_->columns.clear(); + opts_->columns.reserve(selection.size()); + for (const auto& ref : selection) { + ASSERT_OK_AND_ASSIGN(FieldPath path, ref.FindOne(*dataset_schema_)); + opts_->columns.push_back(std::move(path)); + } + } + + void SetScanProjectionAllColumns() { + CheckDatasetSchemaSet(); + opts_->columns = ScanV2Options::AllColumns(*dataset_schema_); + } + + // Shared test cases + void AssertInspectFailure(const std::string& contents, StatusCode code, + const std::string& format_name) { + SCOPED_TRACE("Format: " + format_name + " File contents: " + contents); + constexpr auto file_name = "herp/derp"; + auto make_error_message = [&](const std::string& filename) { + return "Could not open " + format_name + " input source '" + filename + "':"; + }; + const auto buf = std::make_shared(contents); + Status status; + + // Inspecting a buffer fails + status = format_->Inspect(FileSource(buf)).status(); + EXPECT_EQ(code, status.code()); + EXPECT_THAT(status.ToString(), ::testing::HasSubstr(make_error_message(""))); + + ASSERT_OK_AND_EQ(false, format_->IsSupported(FileSource(buf))); + + // Inspecting a file fails + ASSERT_OK_AND_ASSIGN( + auto fs, fs::internal::MockFileSystem::Make(fs::kNoTime, {fs::File(file_name)})); + status = format_->Inspect({file_name, fs}).status(); + EXPECT_EQ(code, status.code()); + EXPECT_THAT(status.ToString(), testing::HasSubstr(make_error_message("herp/derp"))); + + // Discovering a dataset containing the invalid file fails + 
fs::FileSelector s; + s.base_dir = "/"; + s.recursive = true; + FileSystemFactoryOptions options; + ASSERT_OK_AND_ASSIGN(auto factory, + FileSystemDatasetFactory::Make(fs, s, format_, options)); + status = factory->Finish().status(); + EXPECT_EQ(code, status.code()); + EXPECT_THAT( + status.ToString(), + ::testing::AllOf( + ::testing::HasSubstr(make_error_message("/herp/derp")), + ::testing::HasSubstr( + "Error creating dataset. Could not read schema from '/herp/derp':"), + ::testing::HasSubstr("Is this a '" + format_->type_name() + "' file?"))); + } + + void TestInspectFailureWithRelevantError(StatusCode code, + const std::string& format_name) { + const std::vector file_contents{"", "PAR0", "ASDFPAR1", "ARROW1"}; + for (const auto& contents : file_contents) { + AssertInspectFailure(contents, code, format_name); + } + } + + // Inspecting a file should yield the appropriate schema + void TestInspect() { + auto reader = GetRandomData(schema({field("f64", float64())})); + auto source = MakeBufferSource(reader.get()); + + ASSERT_OK_AND_ASSIGN(auto actual, format_->Inspect(*source.get())); + AssertSchemaEqual(*actual, *reader->schema(), /*check_metadata=*/false); + } + + void TestIsSupported() { + auto reader = GetRandomData(schema({field("f64", float64())})); + auto source = MakeBufferSource(reader.get()); + + bool supported = false; + + std::shared_ptr buf = std::make_shared(std::string_view("")); + ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(FileSource(buf))); + ASSERT_EQ(supported, false); + + buf = std::make_shared(std::string_view("corrupted")); + ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(FileSource(buf))); + ASSERT_EQ(supported, false); + + ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(*source)); + EXPECT_EQ(supported, true); + } + + std::shared_ptr WriteToBuffer( + std::shared_ptr schema, + std::shared_ptr options = nullptr) { + auto format = format_; + SetDatasetSchema(schema->fields()); + EXPECT_OK_AND_ASSIGN(auto sink, GetFileSink()); + if (!options) options = format->DefaultWriteOptions(); + + EXPECT_OK_AND_ASSIGN(auto fs, fs::internal::MockFileSystem::Make(fs::kNoTime, {})); + EXPECT_OK_AND_ASSIGN(auto writer, + format->MakeWriter(sink, schema, options, {fs, ""})); + ARROW_EXPECT_OK(writer->Write(GetRandomData(schema).get())); + auto fut = writer->Finish(); + EXPECT_FINISHES(fut); + ARROW_EXPECT_OK(fut.status()); + EXPECT_OK_AND_ASSIGN(auto written, sink->Finish()); + return written; + } + + void TestWrite() { + auto reader = this->GetRandomData(schema({field("f64", float64())})); + auto expected = this->MakeBufferSource(reader.get()); + auto written = this->WriteToBuffer(reader->schema()); + AssertBufferEqual(*written, *expected->buffer()); + } + + void TestCountRows() { + auto options = std::make_shared(); + auto reader = this->GetRandomData(schema({field("f64", float64())})); + auto full_schema = schema({field("f64", float64()), field("part", int64())}); + auto source = this->MakeBufferSource(reader.get()); + + auto fragment = this->MakeFragment(*source); + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(expected_rows()), + fragment->CountRows(literal(true), options)); + + fragment = this->MakeFragment(*source, equal(field_ref("part"), literal(2))); + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(expected_rows()), + fragment->CountRows(literal(true), options)); + + auto predicate = equal(field_ref("part"), literal(1)); + ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema)); + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(0), + 
                              fragment->CountRows(predicate, options));
+
+    predicate = equal(field_ref("part"), literal(2));
+    ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema));
+    ASSERT_FINISHES_OK_AND_EQ(std::make_optional(expected_rows()),
+                              fragment->CountRows(predicate, options));
+
+    predicate = equal(call("add", {field_ref("f64"), literal(3)}), literal(2));
+    ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema));
+    ASSERT_FINISHES_OK_AND_EQ(std::nullopt, fragment->CountRows(predicate, options));
+  }
+  void TestFragmentEquals() {
+    auto options = std::make_shared<ScanOptions>();
+    auto this_schema = schema({field("f64", float64())});
+    auto other_schema = schema({field("f32", float32())});
+    auto reader = this->GetRandomData(this_schema);
+    auto other_reader = this->GetRandomData(other_schema);
+    auto source = this->MakeBufferSource(reader.get());
+    auto other_source = this->MakeBufferSource(other_reader.get());
+
+    auto fragment = this->MakeFragment(*source);
+    EXPECT_TRUE(fragment->Equals(*fragment));
+    auto other = this->MakeFragment(*other_source);
+    EXPECT_FALSE(fragment->Equals(*other));
+  }
+
+ protected:
+  std::shared_ptr<typename FormatHelper::FormatType> format_;
+  std::shared_ptr<ScanV2Options> opts_;
+  std::shared_ptr<Schema> dataset_schema_;
+};
+
+template <typename FormatHelper>
+class FileFormatScanNodeMixin : public FileFormatFixtureMixinV2<FormatHelper>,
+                                public ::testing::WithParamInterface<TestFormatParams> {
+ public:
+  int64_t expected_batches() const { return GetParam().num_batches; }
+  int64_t expected_rows() const { return GetParam().expected_rows(); }
+
+  // Override FileFormatFixtureMixinV2::GetRandomData to parameterize the #
+  // of batches and rows per batch
+  std::shared_ptr<RecordBatchReader> GetRandomData(
+      std::shared_ptr<Schema> schema) override {
+    return MakeGeneratedRecordBatch(schema, GetParam().items_per_batch,
+                                    GetParam().num_batches);
+  }
+
+  // Scan the fragment through the scanner.
+  Result<std::unique_ptr<RecordBatchReader>> Scan(std::shared_ptr<Fragment> fragment,
+                                                  bool add_filter_fields = true) {
+    opts_->dataset =
+        std::make_shared<FragmentDataset>(dataset_schema_, FragmentVector{fragment});
+    if (add_filter_fields) {
+      ARROW_RETURN_NOT_OK(ScanV2Options::AddFieldsNeededForFilter(opts_.get()));
+    }
+    opts_->format_options = GetFormatOptions();
+    ARROW_ASSIGN_OR_RAISE(
+        std::unique_ptr<RecordBatchReader> reader,
+        compute::DeclarationToReader(compute::Declaration("scan2", *opts_),
+                                     GetParam().use_threads));
+    return reader;
+  }
+
+  // Shared test cases
+  void TestScan() {
+    // Basic test to make sure we can scan data
+    auto random_data = GetRandomData(schema({field("f64", float64())}));
+    auto source = this->MakeBufferSource(random_data.get());
+
+    this->SetDatasetSchema(random_data->schema()->fields());
+    auto fragment = this->MakeFragment(*source);
+
+    int64_t row_count = 0;
+    ASSERT_OK_AND_ASSIGN(std::unique_ptr<RecordBatchReader> scanner, Scan(fragment));
+    for (auto maybe_batch : *scanner) {
+      ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+      row_count += batch->num_rows();
+    }
+    ASSERT_EQ(row_count, GetParam().expected_rows());
+  }
+
+  // TestScanBatchSize is no longer relevant because batch size is an internal concern.
+  // Consumers should only really care about batch sizing at the sink.
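The `Scan()` helper above funnels everything through the new `compute::DeclarationToReader` entry point. For context, a sketch of the consumption pattern it enables (the `decl` variable is hypothetical; a null batch signals end-of-stream):

    // Assume `decl` is a compute::Declaration such as the "scan2" declaration above.
    ARROW_ASSIGN_OR_RAISE(std::unique_ptr<arrow::RecordBatchReader> reader,
                          arrow::compute::DeclarationToReader(decl, /*use_threads=*/false));
    std::shared_ptr<arrow::RecordBatch> batch;
    do {
      ARROW_RETURN_NOT_OK(reader->ReadNext(&batch));
      // ... consume `batch` here; it is nullptr once the stream is exhausted ...
    } while (batch != nullptr);
    // Close() ends the underlying plan, draining any batches still in flight.
    ARROW_RETURN_NOT_OK(reader->Close());

With use_threads=false, each ReadNext() call drives the plan on the calling thread via IterateSynchronously (see the thread_pool.h hunk below).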
+ + // Ensure file formats only return columns needed to fulfill filter/projection + void TestScanProjected() { + auto f32 = field("f32", float32()); + auto f64 = field("f64", float64()); + auto i32 = field("i32", int32()); + auto i64 = field("i64", int64()); + this->SetDatasetSchema({f64, i64, f32, i32}); + this->SetScanProjectionRefs({"f64"}); + this->SetScanFilter(equal(field_ref("i32"), literal(0))); + + // We expect f64 since it is asked for and i32 since it is needed for the filter + auto expected_schema = schema({f64, i32}); + + auto reader = this->GetRandomData(dataset_schema_); + auto source = this->MakeBufferSource(reader.get()); + auto fragment = this->MakeFragment(*source); + + int64_t row_count = 0; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr scanner, + this->Scan(fragment)); + for (auto maybe_batch : *scanner) { + ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch); + row_count += batch->num_rows(); + ASSERT_THAT( + batch->schema()->fields(), + ::testing::UnorderedPointwise(PointeesEqual(), expected_schema->fields())) + << "EXPECTED:\n" + << expected_schema->ToString() << "\nACTUAL:\n" + << batch->schema()->ToString(); + } + + ASSERT_EQ(row_count, expected_rows()); + } + + void TestScanMissingFilterField() { + auto f32 = field("f32", float32()); + auto f64 = field("f64", float64()); + this->SetDatasetSchema({f32, f64}); + this->SetScanProjectionRefs({"f64"}); + this->SetScanFilter(equal(field_ref("f32"), literal(0))); + + auto reader = this->GetRandomData(dataset_schema_); + auto source = this->MakeBufferSource(reader.get()); + auto fragment = this->MakeFragment(*source); + + // At the moment, all formats support this. CSV & JSON simply ignore + // the filter field entirely. Parquet filters with statistics which doesn't require + // loading columns. + // + // However, it seems valid that a format would reject this case as well. Perhaps it + // is not worth testing. 
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr scanner, + this->Scan(fragment)); + } + + void TestScanProjectedNested(bool fine_grained_selection = false) { + // "struct1": { + // "f32", + // "i32" + // } + // "struct2": { + // "f64", + // "i64", + // "struct1": { + // "f32", + // "i32" + // } + // } + auto f32 = field("f32", float32()); + auto f64 = field("f64", float64()); + auto i32 = field("i32", int32()); + auto i64 = field("i64", int64()); + auto struct1 = field("struct1", struct_({f32, i32})); + auto struct2 = field("struct2", struct_({f64, i64, struct1})); + this->SetDatasetSchema({struct1, struct2, f32, f64, i32, i64}); + this->SetScanProjectionRefs( + {".struct1.f32", ".struct2.struct1", ".struct2.struct1.f32"}); + this->SetScanFilter(greater_equal(field_ref(FieldRef("struct2", "i64")), literal(0))); + + std::shared_ptr physical_schema; + if (fine_grained_selection) { + // Some formats, like Parquet, let you pluck only a part of a complex type + physical_schema = schema( + {field("struct1", struct_({f32})), field("struct2", struct_({i64, struct1}))}); + } else { + // Otherwise, the entire top-level field is returned + physical_schema = schema({struct1, struct2}); + } + std::shared_ptr projected_schema = schema({ + field(".struct1.f32", float32()), + field(".struct2.struct1", struct1->type()), + field(".struct2.struct1.f32", float32()), + }); + + { + auto reader = this->GetRandomData(dataset_schema_); + auto source = this->MakeBufferSource(reader.get()); + auto fragment = this->MakeFragment(*source); + + int64_t row_count = 0; + ASSERT_OK_AND_ASSIGN(std::unique_ptr scanner, + this->Scan(fragment)); + for (auto maybe_batch : *scanner) { + ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch); + row_count += batch->num_rows(); + AssertSchemaEqual(*batch->schema(), *projected_schema, + /*check_metadata=*/false); + } + ASSERT_EQ(row_count, expected_rows()); + } + { + // File includes a duplicated name in struct2 + auto struct2_physical = field("struct2", struct_({f64, i64, struct1, i64})); + auto reader = + this->GetRandomData(schema({struct1, struct2_physical, f32, f64, i32, i64})); + auto source = this->MakeBufferSource(reader.get()); + auto fragment = this->MakeFragment(*source); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr scanner, + this->Scan(fragment)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("i64"), + scanner->Next().status()); + } + { + // File is missing a child in struct1 + auto struct1_physical = field("struct1", struct_({i32})); + auto reader = + this->GetRandomData(schema({struct1_physical, struct2, f32, f64, i32, i64})); + auto source = this->MakeBufferSource(reader.get()); + auto fragment = this->MakeFragment(*source); + + physical_schema = schema({physical_schema->field(1)}); + + int64_t row_count = 0; + ASSERT_OK_AND_ASSIGN(std::unique_ptr scanner, + this->Scan(fragment)); + for (auto maybe_batch : *scanner) { + ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch); + row_count += batch->num_rows(); + ASSERT_THAT( + batch->schema()->fields(), + ::testing::UnorderedPointwise(PointeesEqual(), physical_schema->fields())) + << "EXPECTED:\n" + << physical_schema->ToString() << "\nACTUAL:\n" + << batch->schema()->ToString(); + } + ASSERT_EQ(row_count, expected_rows()); + } + } + + void TestScanProjectedMissingCols() { + auto f32 = field("f32", float32()); + auto f64 = field("f64", float64()); + auto i32 = field("i32", int32()); + auto i64 = field("i64", int64()); + this->SetDatasetSchema({f64, i64, f32, i32}); + this->SetScanProjectionRefs({"f64", "i32"}); + 
this->SetScanFilter(equal(field_ref("i32"), literal(0))); + + auto data_without_i32 = this->GetRandomData(schema({f64, i64, f32})); + auto data_without_f64 = this->GetRandomData(schema({i64, f32, i32})); + auto data_with_all = this->GetRandomData(schema({f64, i64, f32, i32})); + + auto readers = {data_with_all.get(), data_without_i32.get(), data_without_f64.get()}; + for (auto reader : readers) { + SCOPED_TRACE(reader->schema()->ToString()); + auto source = this->MakeBufferSource(reader); + auto fragment = this->MakeFragment(*source); + + // in the case where a file doesn't contain a referenced field, we materialize it + // as nulls + std::shared_ptr expected_schema = schema({f64, i32}); + + int64_t row_count = 0; + ASSERT_OK_AND_ASSIGN(std::unique_ptr scanner, + this->Scan(fragment)); + for (auto maybe_batch : *scanner) { + ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch); + row_count += batch->num_rows(); + ASSERT_THAT( + batch->schema()->fields(), + ::testing::UnorderedPointwise(PointeesEqual(), expected_schema->fields())) + << "EXPECTED:\n" + << expected_schema->ToString() << "\nACTUAL:\n" + << batch->schema()->ToString(); + } + ASSERT_EQ(row_count, expected_rows()); + } + } + + void TestScanWithDuplicateColumn() { + // A duplicate column is ignored if not requested. + auto i32 = field("i32", int32()); + auto i64 = field("i64", int64()); + this->SetDatasetSchema({i32, i32, i64}); + this->SetScanProjectionRefs({"i64"}); + auto expected_schema = schema({i64}); + auto reader = this->GetRandomData(dataset_schema_); + auto source = this->MakeBufferSource(reader.get()); + auto fragment = this->MakeFragment(*source); + + int64_t row_count = 0; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr scanner, + this->Scan(fragment)); + for (auto maybe_batch : *scanner) { + ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch); + row_count += batch->num_rows(); + AssertSchemaEqual(*batch->schema(), *expected_schema, + /*check_metadata=*/false); + } + + ASSERT_EQ(row_count, expected_rows()); + + // Duplicate columns ok if column selection uses paths + row_count = 0; + expected_schema = schema({i32, i32}); + this->SetScanProjection({{0}, {1}}); + ASSERT_OK_AND_ASSIGN(scanner, this->Scan(fragment)); + for (auto maybe_batch : *scanner) { + ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch); + row_count += batch->num_rows(); + AssertSchemaEqual(*batch->schema(), *expected_schema, + /*check_metadata=*/false); + } + + ASSERT_EQ(row_count, expected_rows()); + } + + void TestScanWithPushdownNulls() { + // Regression test for ARROW-15312 + auto i64 = field("i64", int64()); + this->SetDatasetSchema({i64}); + this->SetScanFilter(is_null(field_ref("i64"))); + + auto rb = RecordBatchFromJSON(schema({i64}), R"([ + [null], + [32] + ])"); + ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make({rb})); + auto source = this->MakeBufferSource(reader.get()); + + auto fragment = this->MakeFragment(*source); + int64_t row_count = 0; + ASSERT_OK_AND_ASSIGN(std::unique_ptr scanner, + this->Scan(fragment)); + for (auto maybe_batch : *scanner) { + ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch); + row_count += batch->num_rows(); + } + ASSERT_EQ(row_count, 1); + } + + protected: + virtual const FragmentScanOptions* GetFormatOptions() = 0; + + using FileFormatFixtureMixinV2::opts_; + using FileFormatFixtureMixinV2::dataset_schema_; +}; + /// \brief A dummy FileFormat implementation class DummyFileFormat : public FileFormat { public: explicit DummyFileFormat(std::shared_ptr schema = NULLPTR) - : schema_(std::move(schema)) {} + : 
FileFormat(/*default_fragment_scan_options=*/nullptr),
+        schema_(std::move(schema)) {}
 
   std::string type_name() const override { return "dummy"; }
 
@@ -959,10 +1502,12 @@ class JSONRecordBatchFileFormat : public FileFormat {
   using SchemaResolver = std::function<Result<std::shared_ptr<Schema>>(const FileSource&)>;
 
   explicit JSONRecordBatchFileFormat(std::shared_ptr<Schema> schema)
-      : resolver_([schema](const FileSource&) { return schema; }) {}
+      : FileFormat(/*default_fragment_scan_opts=*/nullptr),
+        resolver_([schema](const FileSource&) { return schema; }) {}
 
   explicit JSONRecordBatchFileFormat(SchemaResolver resolver)
-      : resolver_(std::move(resolver)) {}
+      : FileFormat(/*default_fragment_scan_opts=*/nullptr),
+        resolver_(std::move(resolver)) {}
 
   bool Equals(const FileFormat& other) const override { return this == &other; }
 
diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h
index 526213ce888..9b3ec54d00d 100644
--- a/cpp/src/arrow/util/thread_pool.h
+++ b/cpp/src/arrow/util/thread_pool.h
@@ -497,5 +497,30 @@ typename Fut::SyncType RunSynchronously(FnOnce<Fut(Executor*)> get_future,
   }
 }
 
+/// \brief Potentially iterate an async generator serially (if use_threads is false)
+/// \see IterateGenerator
+///
+/// If `use_threads` is true, the global CPU executor will be used. Each call to
+/// the iterator will simply wait until the next item is available. Tasks may run in
+/// the background between calls.
+///
+/// If `use_threads` is false, only the calling thread will be used. Each call to
+/// the iterator will use the calling thread to do enough work to generate one item.
+/// Tasks will be left in a queue until the next call and no work will be done between
+/// calls.
+template <typename T>
+Iterator<T> IterateSynchronously(
+    FnOnce<Result<std::function<Future<T>()>>(Executor*)> get_gen, bool use_threads) {
+  if (use_threads) {
+    auto maybe_gen = std::move(get_gen)(GetCpuThreadPool());
+    if (!maybe_gen.ok()) {
+      return MakeErrorIterator<T>(maybe_gen.status());
+    }
+    return MakeGeneratorIterator(*maybe_gen);
+  } else {
+    return SerialExecutor::IterateGenerator<T>(std::move(get_gen));
+  }
+}
+
 }  // namespace internal
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/thread_pool_test.cc b/cpp/src/arrow/util/thread_pool_test.cc
index 799675c6fa9..bce07d6908a 100644
--- a/cpp/src/arrow/util/thread_pool_test.cc
+++ b/cpp/src/arrow/util/thread_pool_test.cc
@@ -398,6 +398,54 @@ TEST(SerialExecutor, FailingIteratorWithCleanup) {
   ASSERT_TRUE(follow_up_ran);
 }
 
+TEST(SerialExecutor, IterateSynchronously) {
+  for (bool use_threads : {false, true}) {
+    FnOnce<Result<AsyncGenerator<TestInt>>(Executor*)> factory = [](Executor* executor) {
+      AsyncGenerator<TestInt> vector_gen = MakeVectorGenerator<TestInt>({1, 2, 3});
+      return MakeTransferredGenerator(vector_gen, executor);
+    };
+
+    Iterator<TestInt> my_it =
+        IterateSynchronously<TestInt>(std::move(factory), use_threads);
+    ASSERT_EQ(TestInt(1), *my_it.Next());
+    ASSERT_EQ(TestInt(2), *my_it.Next());
+    ASSERT_EQ(TestInt(3), *my_it.Next());
+    AssertIteratorExhausted(my_it);
+  }
+}
+
+struct MockGeneratorFactory {
+  explicit MockGeneratorFactory(Executor** captured_executor)
+      : captured_executor(captured_executor) {}
+
+  Result<AsyncGenerator<TestInt>> operator()(Executor* executor) {
+    *captured_executor = executor;
+    return MakeEmptyGenerator<TestInt>();
+  }
+  Executor** captured_executor;
+};
+
+TEST(SerialExecutor, IterateSynchronouslyFactoryFails) {
+  for (bool use_threads : {false, true}) {
+    FnOnce<Result<AsyncGenerator<TestInt>>(Executor*)> factory = [](Executor* executor) {
+      return Status::Invalid("XYZ");
+    };
+
+    Iterator<TestInt> my_it =
+        IterateSynchronously<TestInt>(std::move(factory), use_threads);
+    ASSERT_RAISES(Invalid, my_it.Next());
+  }
+}
+
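Aside: a self-contained restatement of the pattern the tests above exercise, since IterateSynchronously is the primitive that DeclarationToReader's synchronous mode builds on (TestInt and the generator helpers are the same test utilities used above):

    // The factory receives the executor to build the generator against. With
    // use_threads=false this is a serial executor and each Next() call below
    // runs just enough queued tasks on the calling thread to produce one item.
    FnOnce<Result<AsyncGenerator<TestInt>>(Executor*)> factory =
        [](Executor* executor) -> Result<AsyncGenerator<TestInt>> {
      AsyncGenerator<TestInt> gen = MakeVectorGenerator<TestInt>({4, 5, 6});
      return MakeTransferredGenerator(std::move(gen), executor);
    };
    Iterator<TestInt> it = IterateSynchronously<TestInt>(std::move(factory),
                                                         /*use_threads=*/false);
    for (Result<TestInt> maybe_item : it) {
      ARROW_RETURN_NOT_OK(maybe_item.status());
      // ... consume *maybe_item ...
    }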
+TEST(SerialExecutor, IterateSynchronouslyUsesThreadsIfRequested) {
+  Executor* captured_executor;
+  MockGeneratorFactory gen_factory(&captured_executor);
+  IterateSynchronously<TestInt>(gen_factory, true);
+  ASSERT_EQ(internal::GetCpuThreadPool(), captured_executor);
+  IterateSynchronously<TestInt>(gen_factory, false);
+  ASSERT_NE(internal::GetCpuThreadPool(), captured_executor);
+}
+
 class TransferTest : public testing::Test {
  public:
   internal::Executor* executor() { return mock_executor.get(); }