56 changes: 34 additions & 22 deletions cpp/src/arrow/dataset/file_base.cc
@@ -328,51 +328,58 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer {
public:
DatasetWritingSinkNodeConsumer(std::shared_ptr<Schema> schema,
std::unique_ptr<internal::DatasetWriter> dataset_writer,
FileSystemDatasetWriteOptions write_options)
: schema(std::move(schema)),
dataset_writer(std::move(dataset_writer)),
write_options(std::move(write_options)) {}
FileSystemDatasetWriteOptions write_options,
std::shared_ptr<util::AsyncToggle> backpressure_toggle)
: schema_(std::move(schema)),
dataset_writer_(std::move(dataset_writer)),
write_options_(std::move(write_options)),
backpressure_toggle_(std::move(backpressure_toggle)) {}

Status Consume(compute::ExecBatch batch) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch,
batch.ToRecordBatch(schema));
batch.ToRecordBatch(schema_));
return WriteNextBatch(std::move(record_batch), batch.guarantee);
}

Future<> Finish() {
RETURN_NOT_OK(task_group.AddTask([this] { return dataset_writer->Finish(); }));
return task_group.End();
RETURN_NOT_OK(task_group_.AddTask([this] { return dataset_writer_->Finish(); }));
return task_group_.End();
}

private:
Status WriteNextBatch(std::shared_ptr<RecordBatch> batch,
compute::Expression guarantee) {
ARROW_ASSIGN_OR_RAISE(auto groups, write_options.partitioning->Partition(batch));
ARROW_ASSIGN_OR_RAISE(auto groups, write_options_.partitioning->Partition(batch));
batch.reset(); // drop to hopefully conserve memory

if (groups.batches.size() > static_cast<size_t>(write_options.max_partitions)) {
if (groups.batches.size() > static_cast<size_t>(write_options_.max_partitions)) {
return Status::Invalid("Fragment would be written into ", groups.batches.size(),
" partitions. This exceeds the maximum of ",
write_options.max_partitions);
write_options_.max_partitions);
}

for (std::size_t index = 0; index < groups.batches.size(); index++) {
auto partition_expression = and_(groups.expressions[index], guarantee);
auto next_batch = groups.batches[index];
ARROW_ASSIGN_OR_RAISE(std::string destination,
write_options.partitioning->Format(partition_expression));
RETURN_NOT_OK(task_group.AddTask([this, next_batch, destination] {
return dataset_writer->WriteRecordBatch(next_batch, destination);
write_options_.partitioning->Format(partition_expression));
RETURN_NOT_OK(task_group_.AddTask([this, next_batch, destination] {
Future<> has_room = dataset_writer_->WriteRecordBatch(next_batch, destination);
if (!has_room.is_finished() && backpressure_toggle_) {
backpressure_toggle_->Close();
return has_room.Then([this] { backpressure_toggle_->Open(); });
}
return has_room;
}));
}
return Status::OK();
}

std::shared_ptr<Schema> schema;
std::unique_ptr<internal::DatasetWriter> dataset_writer;
FileSystemDatasetWriteOptions write_options;

util::SerializedAsyncTaskGroup task_group;
std::shared_ptr<Schema> schema_;
std::unique_ptr<internal::DatasetWriter> dataset_writer_;
FileSystemDatasetWriteOptions write_options_;
std::shared_ptr<util::AsyncToggle> backpressure_toggle_;
util::SerializedAsyncTaskGroup task_group_;
};

} // namespace
@@ -398,16 +405,19 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio
scanner->options()->projection.call()->options.get())
->field_names;
std::shared_ptr<Dataset> dataset = scanner->dataset();
std::shared_ptr<util::AsyncToggle> backpressure_toggle =
std::make_shared<util::AsyncToggle>();

RETURN_NOT_OK(
compute::Declaration::Sequence(
{
{"scan", ScanNodeOptions{dataset, scanner->options()}},
{"scan", ScanNodeOptions{dataset, scanner->options(), backpressure_toggle}},
{"filter", compute::FilterNodeOptions{scanner->options()->filter}},
{"project",
compute::ProjectNodeOptions{std::move(exprs), std::move(names)}},
{"write",
WriteNodeOptions{write_options, scanner->options()->projected_schema}},
WriteNodeOptions{write_options, scanner->options()->projected_schema,
backpressure_toggle}},
})
.AddToPlan(plan.get()));

@@ -426,14 +436,16 @@ Result<compute::ExecNode*> MakeWriteNode(compute::ExecPlan* plan,
const WriteNodeOptions write_node_options =
checked_cast<const WriteNodeOptions&>(options);
const FileSystemDatasetWriteOptions& write_options = write_node_options.write_options;
std::shared_ptr<Schema> schema = write_node_options.schema;
const std::shared_ptr<Schema>& schema = write_node_options.schema;
const std::shared_ptr<util::AsyncToggle>& backpressure_toggle =
write_node_options.backpressure_toggle;

ARROW_ASSIGN_OR_RAISE(auto dataset_writer,
internal::DatasetWriter::Make(write_options));

std::shared_ptr<DatasetWritingSinkNodeConsumer> consumer =
std::make_shared<DatasetWritingSinkNodeConsumer>(
std::move(schema), std::move(dataset_writer), write_options);
schema, std::move(dataset_writer), write_options, backpressure_toggle);

ARROW_ASSIGN_OR_RAISE(
auto node,
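A note on the pattern above: whenever DatasetWriter::WriteRecordBatch returns a future that is not already finished, the consumer closes the shared toggle, which pauses the scan node, and reopens it once the pending write completes. The following stand-alone sketch models that handshake with a plain condition variable in place of util::AsyncToggle; it is illustrative only (the Toggle class, queue, and thresholds are stand-ins, not Arrow's implementation).

// Conceptual sketch only: a blocking stand-in for util::AsyncToggle, showing the
// close-when-full / reopen-when-drained handshake used by the sink consumer above.
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>

class Toggle {
 public:
  void Close() {
    std::lock_guard<std::mutex> lock(mutex_);
    open_ = false;
  }
  void Open() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      open_ = true;
    }
    cv_.notify_all();
  }
  // The producer ("scan node") blocks here while the toggle is closed.
  void WaitUntilOpen() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return open_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool open_ = true;
};

int main() {
  Toggle toggle;
  std::mutex queue_mutex;
  std::queue<int> queue;
  constexpr std::size_t kMaxQueued = 4;  // stand-in for the writer's row limit
  constexpr int kTotal = 32;

  // Producer: emits batches, pausing whenever the consumer has closed the toggle.
  std::thread producer([&] {
    for (int i = 0; i < kTotal; ++i) {
      toggle.WaitUntilOpen();
      std::lock_guard<std::mutex> lock(queue_mutex);
      queue.push(i);
    }
  });

  // Consumer: drains the queue slowly, closing the toggle when the queue is full
  // and reopening it once there is room again.
  int consumed = 0;
  while (consumed < kTotal) {
    {
      std::lock_guard<std::mutex> lock(queue_mutex);
      if (!queue.empty()) {
        queue.pop();
        ++consumed;
      }
      if (queue.size() >= kMaxQueued) {
        toggle.Close();
      } else {
        toggle.Open();
      }
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  producer.join();
  std::printf("consumed %d batches with a bounded queue\n", consumed);
  return 0;
}

The point of the pattern is that the producer never runs ahead of the consumer by more than the configured queue limit, which is what keeps the dataset writer's memory usage bounded.
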
10 changes: 7 additions & 3 deletions cpp/src/arrow/dataset/file_base.h
@@ -410,12 +410,16 @@ struct ARROW_DS_EXPORT FileSystemDatasetWriteOptions {
/// \brief Wraps FileSystemDatasetWriteOptions for consumption as compute::ExecNodeOptions
class ARROW_DS_EXPORT WriteNodeOptions : public compute::ExecNodeOptions {
public:
explicit WriteNodeOptions(FileSystemDatasetWriteOptions options,
std::shared_ptr<Schema> schema)
: write_options(std::move(options)), schema(std::move(schema)) {}
explicit WriteNodeOptions(
FileSystemDatasetWriteOptions options, std::shared_ptr<Schema> schema,
std::shared_ptr<util::AsyncToggle> backpressure_toggle = NULLPTR)
: write_options(std::move(options)),
schema(std::move(schema)),
backpressure_toggle(std::move(backpressure_toggle)) {}

FileSystemDatasetWriteOptions write_options;
std::shared_ptr<Schema> schema;
std::shared_ptr<util::AsyncToggle> backpressure_toggle;
};

/// @}
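For callers assembling their own ExecPlan rather than going through FileSystemDataset::Write, the same toggle has to be threaded through both ends of the plan. A minimal sketch, mirroring the scan/write wiring added above (filter and project stages omitted for brevity; the function name is illustrative and the includes of the corresponding Arrow compute/dataset headers are assumed rather than shown):

namespace cp = arrow::compute;
namespace ds = arrow::dataset;

// Builds scan -> write into an existing plan, sharing one AsyncToggle so the
// write node can pause the scan node when the dataset writer is saturated.
arrow::Status DeclareBackpressuredWrite(
    cp::ExecPlan* plan, std::shared_ptr<ds::Dataset> dataset,
    std::shared_ptr<ds::ScanOptions> scan_options,
    ds::FileSystemDatasetWriteOptions write_options) {
  auto backpressure_toggle = std::make_shared<arrow::util::AsyncToggle>();
  ARROW_RETURN_NOT_OK(
      cp::Declaration::Sequence(
          {
              {"scan",
               ds::ScanNodeOptions{dataset, scan_options, backpressure_toggle}},
              {"write",
               ds::WriteNodeOptions{std::move(write_options),
                                    scan_options->projected_schema,
                                    backpressure_toggle}},
          })
          .AddToPlan(plan));
  return arrow::Status::OK();
}

Passing no toggle (the NULLPTR default in WriteNodeOptions) keeps the previous behavior, in which the scan node is never paused by the writer.
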
76 changes: 75 additions & 1 deletion python/pyarrow/tests/test_dataset.py
@@ -22,6 +22,8 @@
import pickle
import textwrap
import tempfile
import threading
import time

import numpy as np
import pytest
Expand All @@ -30,7 +32,8 @@
import pyarrow.csv
import pyarrow.feather
import pyarrow.fs as fs
from pyarrow.tests.util import change_cwd, _filesystem_uri, FSProtocolClass
from pyarrow.tests.util import (change_cwd, _filesystem_uri,
FSProtocolClass, ProxyHandler)

try:
import pandas as pd
@@ -3389,6 +3392,77 @@ def test_write_dataset_with_scanner(tempdir):
) == table.drop(["a"]).to_pydict()


def test_write_dataset_with_backpressure(tempdir):
consumer_gate = threading.Event()

# A filesystem that blocks all writes so that we can build
# up backpressure. The writes are released at the end of
# the test.
class GatingFs(ProxyHandler):
def open_output_stream(self, path, metadata):
# Block until the end of the test
consumer_gate.wait()
return self._fs.open_output_stream(path, metadata=metadata)
gating_fs = fs.PyFileSystem(GatingFs(fs.LocalFileSystem()))

schema = pa.schema([pa.field('data', pa.int32())])
# By default, the dataset writer will queue up to 64Mi (67,108,864) rows,
# so with 1M-row batches only ~67 batches fit before backpressure kicks in
batch = pa.record_batch([pa.array(list(range(1_000_000)))], schema=schema)
batches_read = 0
min_backpressure = 67
end = 200

def counting_generator():
nonlocal batches_read
while batches_read < end:
time.sleep(0.01)
batches_read += 1
yield batch

scanner = ds.Scanner.from_batches(
counting_generator(), schema=schema, use_threads=True,
use_async=True)

write_thread = threading.Thread(
target=lambda: ds.write_dataset(
scanner, str(tempdir), format='parquet', filesystem=gating_fs))
write_thread.start()

try:
start = time.time()

def duration():
return time.time() - start

# This test is timing dependent. There is no signal from the C++
# when backpressure has been hit. We don't know exactly when
# backpressure will be hit because it may take some time for the
# signal to get from the sink to the scanner.
#
# The test may emit false positives on slow systems. It could
# theoretically emit a false negative if the scanner managed to read
# and emit all 200 batches before the backpressure signal had a chance
# to propagate but the 0.01s delay in the generator should make that
# scenario unlikely.
Comment on lines +3438 to +3447

Contributor:

Just a comment: if the Python test is slow, why not write this test in C++? There is more control in C++, even a test with a large workload is achievable, and it could be run only under certain conditions so it doesn't always run.

Member Author:

There are already tests in C++ for the scanner backpressure and the dataset writer backpressure. You are correct that we have more control. I was able to use the thread pool's "wait for idle" method to know when backpressure had been hit.

I wanted a Python test to pull everything together and make sure it is actually being utilized correctly (I think it is easy sometimes for Python to get missed due to a configuration parameter or something else). I'd be ok with removing this test, but I don't think we need to add anything to C++. @bkietz thoughts?

Member:

I'd say this is sufficient for this PR

last_value = 0
backpressure_probably_hit = False
while duration() < 10:
if batches_read > min_backpressure:
if batches_read == last_value:
backpressure_probably_hit = True
break
last_value = batches_read
time.sleep(0.5)

assert backpressure_probably_hit

finally:
consumer_gate.set()
write_thread.join()
assert batches_read == end


def test_write_dataset_with_dataset(tempdir):
table = pa.table({'b': ['x', 'y', 'z'], 'c': [1, 2, 3]})

63 changes: 1 addition & 62 deletions python/pyarrow/tests/test_fs.py
@@ -29,7 +29,7 @@

import pyarrow as pa
from pyarrow.tests.test_io import assert_file_not_found
from pyarrow.tests.util import _filesystem_uri
from pyarrow.tests.util import _filesystem_uri, ProxyHandler
from pyarrow.vendored.version import Version

from pyarrow.fs import (FileType, FileInfo, FileSelector, FileSystem,
@@ -143,67 +143,6 @@ def open_append_stream(self, path, metadata):
return pa.BufferOutputStream()


class ProxyHandler(FileSystemHandler):

def __init__(self, fs):
self._fs = fs

def __eq__(self, other):
if isinstance(other, ProxyHandler):
return self._fs == other._fs
return NotImplemented

def __ne__(self, other):
if isinstance(other, ProxyHandler):
return self._fs != other._fs
return NotImplemented

def get_type_name(self):
return "proxy::" + self._fs.type_name

def normalize_path(self, path):
return self._fs.normalize_path(path)

def get_file_info(self, paths):
return self._fs.get_file_info(paths)

def get_file_info_selector(self, selector):
return self._fs.get_file_info(selector)

def create_dir(self, path, recursive):
return self._fs.create_dir(path, recursive=recursive)

def delete_dir(self, path):
return self._fs.delete_dir(path)

def delete_dir_contents(self, path):
return self._fs.delete_dir_contents(path)

def delete_root_dir_contents(self):
return self._fs.delete_dir_contents("", accept_root_dir=True)

def delete_file(self, path):
return self._fs.delete_file(path)

def move(self, src, dest):
return self._fs.move(src, dest)

def copy_file(self, src, dest):
return self._fs.copy_file(src, dest)

def open_input_stream(self, path):
return self._fs.open_input_stream(path)

def open_input_file(self, path):
return self._fs.open_input_file(path)

def open_output_stream(self, path, metadata):
return self._fs.open_output_stream(path, metadata=metadata)

def open_append_stream(self, path, metadata):
return self._fs.open_append_stream(path, metadata=metadata)


@pytest.fixture
def localfs(request, tempdir):
return dict(
66 changes: 66 additions & 0 deletions python/pyarrow/tests/util.py
@@ -33,6 +33,7 @@
import pytest

import pyarrow as pa
import pyarrow.fs


def randsign():
@@ -251,6 +252,71 @@ def __fspath__(self):
return str(self._path)


class ProxyHandler(pyarrow.fs.FileSystemHandler):
"""
A dataset handler that proxies to an underlying filesystem. Useful
to partially wrap an existing filesystem with partial changes.
"""

def __init__(self, fs):
self._fs = fs

def __eq__(self, other):
if isinstance(other, ProxyHandler):
return self._fs == other._fs
return NotImplemented

def __ne__(self, other):
if isinstance(other, ProxyHandler):
return self._fs != other._fs
return NotImplemented

def get_type_name(self):
return "proxy::" + self._fs.type_name

def normalize_path(self, path):
return self._fs.normalize_path(path)

def get_file_info(self, paths):
return self._fs.get_file_info(paths)

def get_file_info_selector(self, selector):
return self._fs.get_file_info(selector)

def create_dir(self, path, recursive):
return self._fs.create_dir(path, recursive=recursive)

def delete_dir(self, path):
return self._fs.delete_dir(path)

def delete_dir_contents(self, path):
return self._fs.delete_dir_contents(path)

def delete_root_dir_contents(self):
return self._fs.delete_dir_contents("", accept_root_dir=True)

def delete_file(self, path):
return self._fs.delete_file(path)

def move(self, src, dest):
return self._fs.move(src, dest)

def copy_file(self, src, dest):
return self._fs.copy_file(src, dest)

def open_input_stream(self, path):
return self._fs.open_input_stream(path)

def open_input_file(self, path):
return self._fs.open_input_file(path)

def open_output_stream(self, path, metadata):
return self._fs.open_output_stream(path, metadata=metadata)

def open_append_stream(self, path, metadata):
return self._fs.open_append_stream(path, metadata=metadata)


def get_raise_signal():
if sys.version_info >= (3, 8):
return signal.raise_signal