diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index c6cf3649589..4ff3c6d2b4e 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -328,51 +328,58 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer {
  public:
  DatasetWritingSinkNodeConsumer(std::shared_ptr<Schema> schema,
                                 std::unique_ptr<internal::DatasetWriter> dataset_writer,
-                                FileSystemDatasetWriteOptions write_options)
-      : schema(std::move(schema)),
-        dataset_writer(std::move(dataset_writer)),
-        write_options(std::move(write_options)) {}
+                                FileSystemDatasetWriteOptions write_options,
+                                std::shared_ptr<util::AsyncToggle> backpressure_toggle)
+      : schema_(std::move(schema)),
+        dataset_writer_(std::move(dataset_writer)),
+        write_options_(std::move(write_options)),
+        backpressure_toggle_(std::move(backpressure_toggle)) {}

  Status Consume(compute::ExecBatch batch) {
    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch,
-                          batch.ToRecordBatch(schema));
+                          batch.ToRecordBatch(schema_));
    return WriteNextBatch(std::move(record_batch), batch.guarantee);
  }

  Future<> Finish() {
-    RETURN_NOT_OK(task_group.AddTask([this] { return dataset_writer->Finish(); }));
-    return task_group.End();
+    RETURN_NOT_OK(task_group_.AddTask([this] { return dataset_writer_->Finish(); }));
+    return task_group_.End();
  }

 private:
  Status WriteNextBatch(std::shared_ptr<RecordBatch> batch,
                        compute::Expression guarantee) {
-    ARROW_ASSIGN_OR_RAISE(auto groups, write_options.partitioning->Partition(batch));
+    ARROW_ASSIGN_OR_RAISE(auto groups, write_options_.partitioning->Partition(batch));
    batch.reset();  // drop to hopefully conserve memory
-    if (groups.batches.size() > static_cast<std::size_t>(write_options.max_partitions)) {
+    if (groups.batches.size() > static_cast<std::size_t>(write_options_.max_partitions)) {
      return Status::Invalid("Fragment would be written into ", groups.batches.size(),
                             " partitions. This exceeds the maximum of ",
-                             write_options.max_partitions);
+                             write_options_.max_partitions);
    }
    for (std::size_t index = 0; index < groups.batches.size(); index++) {
      auto partition_expression = and_(groups.expressions[index], guarantee);
      auto next_batch = groups.batches[index];
      ARROW_ASSIGN_OR_RAISE(std::string destination,
-                            write_options.partitioning->Format(partition_expression));
-      RETURN_NOT_OK(task_group.AddTask([this, next_batch, destination] {
-        return dataset_writer->WriteRecordBatch(next_batch, destination);
+                            write_options_.partitioning->Format(partition_expression));
+      RETURN_NOT_OK(task_group_.AddTask([this, next_batch, destination] {
+        Future<> has_room = dataset_writer_->WriteRecordBatch(next_batch, destination);
+        if (!has_room.is_finished() && backpressure_toggle_) {
+          backpressure_toggle_->Close();
+          return has_room.Then([this] { backpressure_toggle_->Open(); });
+        }
+        return has_room;
      }));
    }
    return Status::OK();
  }

-  std::shared_ptr<Schema> schema;
-  std::unique_ptr<internal::DatasetWriter> dataset_writer;
-  FileSystemDatasetWriteOptions write_options;
-
-  util::SerializedAsyncTaskGroup task_group;
+  std::shared_ptr<Schema> schema_;
+  std::unique_ptr<internal::DatasetWriter> dataset_writer_;
+  FileSystemDatasetWriteOptions write_options_;
+  std::shared_ptr<util::AsyncToggle> backpressure_toggle_;
+  util::SerializedAsyncTaskGroup task_group_;
 };

 }  // namespace

@@ -398,16 +405,19 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio
          scanner->options()->projection.call()->options.get())
          ->field_names;
  std::shared_ptr<Dataset> dataset = scanner->dataset();
+  std::shared_ptr<util::AsyncToggle> backpressure_toggle =
+      std::make_shared<util::AsyncToggle>();
  RETURN_NOT_OK(
      compute::Declaration::Sequence(
          {
-              {"scan", ScanNodeOptions{dataset, scanner->options()}},
+              {"scan", ScanNodeOptions{dataset, scanner->options(), backpressure_toggle}},
              {"filter", compute::FilterNodeOptions{scanner->options()->filter}},
              {"project",
               compute::ProjectNodeOptions{std::move(exprs), std::move(names)}},
              {"write",
-               WriteNodeOptions{write_options, scanner->options()->projected_schema}},
+               WriteNodeOptions{write_options, scanner->options()->projected_schema,
+                                backpressure_toggle}},
          })
          .AddToPlan(plan.get()));

@@ -426,14 +436,16 @@ Result<compute::ExecNode*> MakeWriteNode(compute::ExecPlan* plan,
  const WriteNodeOptions write_node_options =
      checked_cast<const WriteNodeOptions&>(options);
  const FileSystemDatasetWriteOptions& write_options = write_node_options.write_options;
-  std::shared_ptr<Schema> schema = write_node_options.schema;
+  const std::shared_ptr<Schema>& schema = write_node_options.schema;
+  const std::shared_ptr<util::AsyncToggle>& backpressure_toggle =
+      write_node_options.backpressure_toggle;

  ARROW_ASSIGN_OR_RAISE(auto dataset_writer,
                        internal::DatasetWriter::Make(write_options));

  std::shared_ptr<DatasetWritingSinkNodeConsumer> consumer =
      std::make_shared<DatasetWritingSinkNodeConsumer>(
-          std::move(schema), std::move(dataset_writer), write_options);
+          schema, std::move(dataset_writer), write_options, backpressure_toggle);

  ARROW_ASSIGN_OR_RAISE(
      auto node,
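The core of the change is the small dance inside `AddTask`: if the writer's `WriteRecordBatch` future is not already finished, the writer's queue is full, so the consumer closes the shared toggle (pausing the scan node) and reopens it once the write completes. Below is a minimal, thread-based Python sketch of that close-while-full / open-when-drained pattern. It is illustrative only: `Toggle`, `source`, `sink`, and `capacity` are hypothetical names, and the real C++ toggle is asynchronous rather than event-based.

```python
import queue
import threading
import time


class Toggle:
    """Toy stand-in for the C++ util::AsyncToggle used in the diff above:
    a gate the sink can close to pause the source and reopen later."""

    def __init__(self):
        self._open = threading.Event()
        self._open.set()  # starts open

    def open(self):
        self._open.set()

    def close(self):
        self._open.clear()

    def wait_until_open(self):
        self._open.wait()


def run_pipeline(num_batches=50, capacity=4):
    toggle = Toggle()
    buffered = queue.Queue()

    def source():
        # Plays the scan node: checks the toggle before emitting each batch,
        # so closing the toggle pauses production.
        for i in range(num_batches):
            toggle.wait_until_open()
            buffered.put(i)
        buffered.put(None)  # end-of-stream marker

    def sink():
        # Plays the dataset writer: closes the toggle while saturated and
        # reopens it once the backlog drains (mirrors has_room.Then(...)).
        while buffered.get() is not None:
            if buffered.qsize() >= capacity:
                toggle.close()
            else:
                toggle.open()
            time.sleep(0.001)  # stand-in for a slow write

    threads = [threading.Thread(target=source), threading.Thread(target=sink)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()


run_pipeline()
```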
diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h
index a645c2c8b08..3c7b8258963 100644
--- a/cpp/src/arrow/dataset/file_base.h
+++ b/cpp/src/arrow/dataset/file_base.h
@@ -410,12 +410,16 @@ struct ARROW_DS_EXPORT FileSystemDatasetWriteOptions {
 /// \brief Wraps FileSystemDatasetWriteOptions for consumption as compute::ExecNodeOptions
 class ARROW_DS_EXPORT WriteNodeOptions : public compute::ExecNodeOptions {
  public:
-  explicit WriteNodeOptions(FileSystemDatasetWriteOptions options,
-                            std::shared_ptr<Schema> schema)
-      : write_options(std::move(options)), schema(std::move(schema)) {}
+  explicit WriteNodeOptions(
+      FileSystemDatasetWriteOptions options, std::shared_ptr<Schema> schema,
+      std::shared_ptr<util::AsyncToggle> backpressure_toggle = NULLPTR)
+      : write_options(std::move(options)),
+        schema(std::move(schema)),
+        backpressure_toggle(std::move(backpressure_toggle)) {}

  FileSystemDatasetWriteOptions write_options;
  std::shared_ptr<Schema> schema;
+  std::shared_ptr<util::AsyncToggle> backpressure_toggle;
 };

 /// @}
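Because the toggle defaults to `NULLPTR`, existing callers of `WriteNodeOptions` compile unchanged, and pyarrow picks the behavior up automatically through `FileSystemDataset::Write`. A minimal sketch of a write that now runs with backpressure from Python, with no new API surface (the path and data here are illustrative):

```python
import pyarrow as pa
import pyarrow.dataset as ds

table = pa.table({"data": pa.array(range(1000), type=pa.int32())})
# The scan -> filter -> project -> write plan assembled in file_base.cc
# now shares one toggle between the scan and write nodes, so a slow
# filesystem throttles the scanner instead of buffering without bound.
ds.write_dataset(table, "/tmp/backpressure_demo", format="parquet")
```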
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index f6db5c065d9..e5590c4a6bf 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -22,6 +22,8 @@
 import pickle
 import textwrap
 import tempfile
+import threading
+import time

 import numpy as np
 import pytest
@@ -30,7 +32,8 @@
 import pyarrow.csv
 import pyarrow.feather
 import pyarrow.fs as fs
-from pyarrow.tests.util import change_cwd, _filesystem_uri, FSProtocolClass
+from pyarrow.tests.util import (change_cwd, _filesystem_uri,
+                                FSProtocolClass, ProxyHandler)

 try:
     import pandas as pd
@@ -3389,6 +3392,77 @@ def test_write_dataset_with_scanner(tempdir):
     ) == table.drop(["a"]).to_pydict()


+def test_write_dataset_with_backpressure(tempdir):
+    consumer_gate = threading.Event()
+
+    # A filesystem that blocks all writes so that we can build
+    # up backpressure.  The writes are released at the end of
+    # the test.
+    class GatingFs(ProxyHandler):
+        def open_output_stream(self, path, metadata):
+            # Block until the end of the test
+            consumer_gate.wait()
+            return self._fs.open_output_stream(path, metadata=metadata)
+    gating_fs = fs.PyFileSystem(GatingFs(fs.LocalFileSystem()))
+
+    schema = pa.schema([pa.field('data', pa.int32())])
+    # By default, the dataset writer will queue up 64Mi rows, so
+    # with batches of 1M rows it should only fit ~67 batches
+    batch = pa.record_batch([pa.array(list(range(1_000_000)))], schema=schema)
+    batches_read = 0
+    min_backpressure = 67
+    end = 200
+
+    def counting_generator():
+        nonlocal batches_read
+        while batches_read < end:
+            time.sleep(0.01)
+            batches_read += 1
+            yield batch
+
+    scanner = ds.Scanner.from_batches(
+        counting_generator(), schema=schema, use_threads=True,
+        use_async=True)
+
+    write_thread = threading.Thread(
+        target=lambda: ds.write_dataset(
+            scanner, str(tempdir), format='parquet', filesystem=gating_fs))
+    write_thread.start()
+
+    try:
+        start = time.time()
+
+        def duration():
+            return time.time() - start
+
+        # This test is timing dependent.  There is no signal from the C++
+        # side when backpressure has been hit.  We don't know exactly when
+        # backpressure will be hit because it may take some time for the
+        # signal to get from the sink to the scanner.
+        #
+        # The test may emit false positives on slow systems.  It could
+        # theoretically emit a false negative if the scanner managed to read
+        # and emit all 200 batches before the backpressure signal had a
+        # chance to propagate, but the 0.01s delay in the generator should
+        # make that scenario unlikely.
+        last_value = 0
+        backpressure_probably_hit = False
+        while duration() < 10:
+            if batches_read > min_backpressure:
+                if batches_read == last_value:
+                    backpressure_probably_hit = True
+                    break
+                last_value = batches_read
+            time.sleep(0.5)
+
+        assert backpressure_probably_hit
+
+    finally:
+        consumer_gate.set()
+        write_thread.join()
+        assert batches_read == end
+
+
 def test_write_dataset_with_dataset(tempdir):
     table = pa.table({'b': ['x', 'y', 'z'], 'c': [1, 2, 3]})
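The `min_backpressure = 67` threshold follows directly from the queue limit the comment mentions: 64Mi is a binary quantity while the batches are 1M decimal rows each. A quick check of that arithmetic (the 64Mi default is taken from the comment above, not queried from any API):

```python
queued_rows = 64 * 1024 * 1024       # 64Mi rows buffered by the dataset writer
rows_per_batch = 1_000_000           # each generated batch holds 1M rows
print(queued_rows / rows_per_batch)  # -> 67.108864, hence ~67 batches in flight
```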
diff --git a/python/pyarrow/tests/test_fs.py b/python/pyarrow/tests/test_fs.py
index 684d89f5b0d..d4d8367d1e8 100644
--- a/python/pyarrow/tests/test_fs.py
+++ b/python/pyarrow/tests/test_fs.py
@@ -29,7 +29,7 @@
 import pyarrow as pa
 from pyarrow.tests.test_io import assert_file_not_found
-from pyarrow.tests.util import _filesystem_uri
+from pyarrow.tests.util import _filesystem_uri, ProxyHandler
 from pyarrow.vendored.version import Version

 from pyarrow.fs import (FileType, FileInfo, FileSelector, FileSystem,
@@ -143,67 +143,6 @@ def open_append_stream(self, path, metadata):
         return pa.BufferOutputStream()


-class ProxyHandler(FileSystemHandler):
-
-    def __init__(self, fs):
-        self._fs = fs
-
-    def __eq__(self, other):
-        if isinstance(other, ProxyHandler):
-            return self._fs == other._fs
-        return NotImplemented
-
-    def __ne__(self, other):
-        if isinstance(other, ProxyHandler):
-            return self._fs != other._fs
-        return NotImplemented
-
-    def get_type_name(self):
-        return "proxy::" + self._fs.type_name
-
-    def normalize_path(self, path):
-        return self._fs.normalize_path(path)
-
-    def get_file_info(self, paths):
-        return self._fs.get_file_info(paths)
-
-    def get_file_info_selector(self, selector):
-        return self._fs.get_file_info(selector)
-
-    def create_dir(self, path, recursive):
-        return self._fs.create_dir(path, recursive=recursive)
-
-    def delete_dir(self, path):
-        return self._fs.delete_dir(path)
-
-    def delete_dir_contents(self, path):
-        return self._fs.delete_dir_contents(path)
-
-    def delete_root_dir_contents(self):
-        return self._fs.delete_dir_contents("", accept_root_dir=True)
-
-    def delete_file(self, path):
-        return self._fs.delete_file(path)
-
-    def move(self, src, dest):
-        return self._fs.move(src, dest)
-
-    def copy_file(self, src, dest):
-        return self._fs.copy_file(src, dest)
-
-    def open_input_stream(self, path):
-        return self._fs.open_input_stream(path)
-
-    def open_input_file(self, path):
-        return self._fs.open_input_file(path)
-
-    def open_output_stream(self, path, metadata):
-        return self._fs.open_output_stream(path, metadata=metadata)
-
-    def open_append_stream(self, path, metadata):
-        return self._fs.open_append_stream(path, metadata=metadata)
-
-
 @pytest.fixture
 def localfs(request, tempdir):
     return dict(
+ """ + + def __init__(self, fs): + self._fs = fs + + def __eq__(self, other): + if isinstance(other, ProxyHandler): + return self._fs == other._fs + return NotImplemented + + def __ne__(self, other): + if isinstance(other, ProxyHandler): + return self._fs != other._fs + return NotImplemented + + def get_type_name(self): + return "proxy::" + self._fs.type_name + + def normalize_path(self, path): + return self._fs.normalize_path(path) + + def get_file_info(self, paths): + return self._fs.get_file_info(paths) + + def get_file_info_selector(self, selector): + return self._fs.get_file_info(selector) + + def create_dir(self, path, recursive): + return self._fs.create_dir(path, recursive=recursive) + + def delete_dir(self, path): + return self._fs.delete_dir(path) + + def delete_dir_contents(self, path): + return self._fs.delete_dir_contents(path) + + def delete_root_dir_contents(self): + return self._fs.delete_dir_contents("", accept_root_dir=True) + + def delete_file(self, path): + return self._fs.delete_file(path) + + def move(self, src, dest): + return self._fs.move(src, dest) + + def copy_file(self, src, dest): + return self._fs.copy_file(src, dest) + + def open_input_stream(self, path): + return self._fs.open_input_stream(path) + + def open_input_file(self, path): + return self._fs.open_input_file(path) + + def open_output_stream(self, path, metadata): + return self._fs.open_output_stream(path, metadata=metadata) + + def open_append_stream(self, path, metadata): + return self._fs.open_append_stream(path, metadata=metadata) + + def get_raise_signal(): if sys.version_info >= (3, 8): return signal.raise_signal