diff --git a/python/pyarrow/parquet/__init__.py b/python/pyarrow/parquet/__init__.py
index 967d39d3db0..a9437aec05d 100644
--- a/python/pyarrow/parquet/__init__.py
+++ b/python/pyarrow/parquet/__init__.py
@@ -3125,7 +3125,6 @@ def file_visitor(written_file):
             "implementation."
         )
         metadata_collector = kwargs.pop('metadata_collector', None)
-        file_visitor = None
         if metadata_collector is not None:
             def file_visitor(written_file):
                 metadata_collector.append(written_file.metadata)
@@ -3140,15 +3139,15 @@ def file_visitor(written_file):
         if filesystem is not None:
             filesystem = _ensure_filesystem(filesystem)
 
-        partitioning = None
         if partition_cols:
             part_schema = table.select(partition_cols).schema
             partitioning = ds.partitioning(part_schema, flavor="hive")
 
         if basename_template is None:
             basename_template = guid() + '-{i}.parquet'
-        if existing_data_behavior is None:
-            existing_data_behavior = 'overwrite_or_ignore'
+
+        if existing_data_behavior is None:
+            existing_data_behavior = 'overwrite_or_ignore'
 
         ds.write_dataset(
             table, root_path, filesystem=filesystem,
diff --git a/python/pyarrow/tests/parquet/test_dataset.py b/python/pyarrow/tests/parquet/test_dataset.py
index 1bcaf4b9e45..bbf27a98a3b 100644
--- a/python/pyarrow/tests/parquet/test_dataset.py
+++ b/python/pyarrow/tests/parquet/test_dataset.py
@@ -17,6 +17,7 @@
 
 import datetime
 import os
+import pathlib
 
 import numpy as np
 import pytest
@@ -1782,3 +1783,29 @@ def test_parquet_write_to_dataset_unsupported_keywards_in_legacy(tempdir):
     with pytest.raises(ValueError, match="existing_data_behavior"):
         pq.write_to_dataset(table, path, use_legacy_dataset=True,
                             existing_data_behavior='error')
+
+
+@pytest.mark.dataset
+def test_parquet_write_to_dataset_exposed_keywords(tempdir):
+    table = pa.table({'a': [1, 2, 3]})
+    path = tempdir / 'partitioning'
+
+    paths_written = []
+
+    def file_visitor(written_file):
+        paths_written.append(written_file.path)
+
+    basename_template = 'part-{i}.parquet'
+
+    pq.write_to_dataset(table, path, partitioning=["a"],
+                        file_visitor=file_visitor,
+                        basename_template=basename_template,
+                        use_legacy_dataset=False)
+
+    expected_paths = {
+        path / '1' / 'part-0.parquet',
+        path / '2' / 'part-0.parquet',
+        path / '3' / 'part-0.parquet'
+    }
+    paths_written_set = set(map(pathlib.Path, paths_written))
+    assert paths_written_set == expected_paths