diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index de4dab3b9..2094e7df2 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -12,6 +12,7 @@
 import dask.dataframe.methods as methods
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 from dask import compute, delayed
 from dask.array import Array
 from dask.base import DaskMethodsMixin, is_dask_collection, named_schedulers
@@ -57,10 +58,12 @@
 )
 from dask.widgets import get_template
 from fsspec.utils import stringify_path
+from packaging.version import parse as parse_version
 from pandas import CategoricalDtype
 from pandas.api.types import is_bool_dtype, is_datetime64_any_dtype, is_numeric_dtype
 from pandas.api.types import is_scalar as pd_is_scalar
 from pandas.api.types import is_timedelta64_dtype
+from pyarrow import fs as pa_fs
 from tlz import first
 
 import dask_expr._backends  # noqa: F401
@@ -4626,7 +4629,11 @@ def read_parquet(
     engine=None,
     **kwargs,
 ):
-    from dask_expr.io.parquet import ReadParquet, _set_parquet_engine
+    from dask_expr.io.parquet import (
+        ReadParquetFSSpec,
+        ReadParquetPyarrowFS,
+        _set_parquet_engine,
+    )
 
     if not isinstance(path, str):
         path = stringify_path(path)
@@ -4639,8 +4646,61 @@ def read_parquet(
             if op == "in" and not isinstance(val, (set, list, tuple)):
                 raise TypeError("Value of 'in' filter must be a list, set or tuple.")
 
+    if (
+        isinstance(filesystem, pa_fs.FileSystem)
+        or isinstance(filesystem, str)
+        and filesystem.lower() in ("arrow", "pyarrow")
+    ):
+        if parse_version(pa.__version__) < parse_version("15.0.0"):
+            raise ValueError(
+                "pyarrow>=15.0.0 is required to use the pyarrow filesystem."
+            )
+        if calculate_divisions:
+            raise NotImplementedError(
+                "calculate_divisions is not supported when using the pyarrow filesystem."
+            )
+        if metadata_task_size is not None:
+            raise NotImplementedError(
+                "metadata_task_size is not supported when using the pyarrow filesystem."
+            )
+        if split_row_groups != "infer":
+            raise NotImplementedError(
+                "split_row_groups is not supported when using the pyarrow filesystem."
+            )
+        if blocksize is not None and blocksize != "default":
+            raise NotImplementedError(
+                "blocksize is not supported when using the pyarrow filesystem."
+            )
+        if aggregate_files is not None:
+            raise NotImplementedError(
+                "aggregate_files is not supported when using the pyarrow filesystem."
+            )
+        if parquet_file_extension != (".parq", ".parquet", ".pq"):
+            raise NotImplementedError(
+                "parquet_file_extension is not supported when using the pyarrow filesystem."
+            )
+        if engine is not None:
+            raise NotImplementedError(
+                "engine is not supported when using the pyarrow filesystem."
+            )
+
+        return new_collection(
+            ReadParquetPyarrowFS(
+                path,
+                columns=_convert_to_list(columns),
+                filters=filters,
+                categories=categories,
+                index=index,
+                storage_options=storage_options,
+                filesystem=filesystem,
+                ignore_metadata_file=ignore_metadata_file,
+                kwargs=kwargs,
+                _series=isinstance(columns, str),
+            )
+        )
+
     return new_collection(
-        ReadParquet(
+        ReadParquetFSSpec(
             path,
             columns=_convert_to_list(columns),
             filters=filters,
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 8cd9e8c3d..a3e0a21ec 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -4,12 +4,15 @@
 import itertools
 import operator
 import warnings
+from abc import abstractmethod
 from collections import defaultdict
 from functools import cached_property
 
 import dask
+import pandas as pd
 import pyarrow as pa
 import pyarrow.dataset as pa_ds
+import pyarrow.fs as pa_fs
 import pyarrow.parquet as pq
 import tlz as toolz
 from dask.base import normalize_token, tokenize
@@ -39,17 +42,45 @@
     And,
     Blockwise,
     Expr,
+    Filter,
     Index,
     Lengths,
     Literal,
     Or,
     Projection,
+    are_co_aligned,
     determine_column_projection,
 )
 from dask_expr._reductions import Len
 from dask_expr._util import _convert_to_list, _tokenize_deterministic
 from dask_expr.io import BlockwiseIO, PartitionsFiltered
 
+
+@normalize_token.register(pa.fs.FileInfo)
+def _tokenize_fileinfo(fileinfo):
+    return type(fileinfo).__name__, (
+        fileinfo.path,
+        fileinfo.size,
+        fileinfo.mtime_ns,
+        fileinfo.size,
+    )
+
+
+PYARROW_NULLABLE_DTYPE_MAPPING = {
+    pa.int8(): pd.Int8Dtype(),
+    pa.int16(): pd.Int16Dtype(),
+    pa.int32(): pd.Int32Dtype(),
+    pa.int64(): pd.Int64Dtype(),
+    pa.uint8(): pd.UInt8Dtype(),
+    pa.uint16(): pd.UInt16Dtype(),
+    pa.uint32(): pd.UInt32Dtype(),
+    pa.uint64(): pd.UInt64Dtype(),
+    pa.bool_(): pd.BooleanDtype(),
+    pa.string(): pd.StringDtype(),
+    pa.float32(): pd.Float32Dtype(),
+    pa.float64(): pd.Float64Dtype(),
+}
+
 NONE_LABEL = "__null_dask_index__"
 
 _CACHED_PLAN_SIZE = 10
@@ -406,7 +437,330 @@ def to_parquet(
     return out
 
 
+def _determine_type_mapper(
+    *, user_types_mapper=None, dtype_backend=None, convert_string=True
+):
+    type_mappers = []
+
+    def pyarrow_type_mapper(pyarrow_dtype):
+        # Special case pyarrow strings to use a more feature-complete dtype
+        # See https://github.com/pandas-dev/pandas/issues/50074
+        if pyarrow_dtype == pa.string():
+            return pd.StringDtype("pyarrow")
+        else:
+            return pd.ArrowDtype(pyarrow_dtype)
+
+    # always use the user-defined mapper first, if available
+    if user_types_mapper is not None:
+        type_mappers.append(user_types_mapper)
+
+    # next in priority is converting strings
+    if convert_string:
+        type_mappers.append({pa.string(): pd.StringDtype("pyarrow")}.get)
+        type_mappers.append({pa.date32(): pd.ArrowDtype(pa.date32())}.get)
+        type_mappers.append({pa.date64(): pd.ArrowDtype(pa.date64())}.get)
+
+        def _convert_decimal_type(type):
+            if pa.types.is_decimal(type):
+                return pd.ArrowDtype(type)
+            return None
+
+        type_mappers.append(_convert_decimal_type)
+
+    # and then nullable types
+    if dtype_backend == "numpy_nullable":
+        type_mappers.append(PYARROW_NULLABLE_DTYPE_MAPPING.get)
+    elif dtype_backend == "pyarrow":
+        type_mappers.append(pyarrow_type_mapper)
+
+    def default_types_mapper(pyarrow_dtype):
+        """Try all type mappers in order, starting from the user type mapper."""
+        for type_converter in type_mappers:
+            converted_type = type_converter(pyarrow_dtype)
+            if converted_type is not None:
+                return converted_type
+
+    if len(type_mappers) > 0:
+        return default_types_mapper
+
+
 class ReadParquet(PartitionsFiltered, BlockwiseIO):
+    _pq_length_stats = None
+    _absorb_projections = True
+    _filter_passthrough = False
+
+    def _filter_passthrough_available(self, parent, dependents):
+        return (
+            super()._filter_passthrough_available(parent, dependents)
+            and (isinstance(parent.predicate, (LE, GE, LT, GT, EQ, NE, And, Or)))
+            and _DNF.extract_pq_filters(self, parent.predicate)._filters is not None
+        )
+
+    def _simplify_up(self, parent, dependents):
+        if isinstance(parent, Index):
+            # Column projection
+            columns = determine_column_projection(self, parent, dependents)
+            if set(columns) == set(self.columns):
+                return
+            columns = [col for col in self.columns if col in columns]
+            return self.substitute_parameters({"columns": columns, "_series": False})
+
+        if isinstance(parent, Projection):
+            return super()._simplify_up(parent, dependents)
+
+        if isinstance(parent, Filter) and self._filter_passthrough_available(
+            parent, dependents
+        ):
+            # Predicate pushdown
+            filters = _DNF.extract_pq_filters(self, parent.predicate)
+            if filters._filters is not None:
+                return self.substitute_parameters(
+                    {
+                        "filters": filters.combine(
+                            self.operand("filters")
+                        ).to_list_tuple()
+                    }
+                )
+
+        if isinstance(parent, Lengths):
+            _lengths = self._get_lengths()
+            if _lengths:
+                return Literal(_lengths)
+
+        if isinstance(parent, Len):
+            _lengths = self._get_lengths()
+            if _lengths:
+                return Literal(sum(_lengths))
+
+    @property
+    def columns(self):
+        columns_operand = self.operand("columns")
+        if columns_operand is None:
+            return list(self._meta.columns)
+        else:
+            return _convert_to_list(columns_operand)
+
+    @cached_property
+    def _name(self):
+        return (
+            funcname(type(self)).lower()
+            + "-"
+            + _tokenize_deterministic(self.checksum, *self.operands[:-1])
+        )
+
+    @property
+    def checksum(self):
+        return self._dataset_info["checksum"]
+
+    def _tree_repr_argument_construction(self, i, op, header):
+        if self._parameters[i] == "_dataset_info_cache":
+            # Don't print this, very ugly
+            return header
+        return super()._tree_repr_argument_construction(i, op, header)
+
+    @property
+    def _meta(self):
+        meta = self._dataset_info["base_meta"]
+        columns = _convert_to_list(self.operand("columns"))
+        if self._series:
+            assert len(columns) > 0
+            return meta[columns[0]]
+        elif columns is not None:
+            return meta[columns]
+        return meta
+
+    @abstractmethod
+    def _divisions(self):
+        raise NotImplementedError
+
+    @property
+    def _fusion_compression_factor(self):
+        if self.operand("columns") is None:
+            return 1
+        nr_original_columns = len(self._dataset_info["schema"].names) - 1
+        return max(
+            len(_convert_to_list(self.operand("columns"))) / nr_original_columns, 0.001
+        )
+
+
+class ReadParquetPyarrowFS(ReadParquet):
+    _parameters = [
+        "path",
+        "columns",
+        "filters",
+        "categories",
+        "index",
+        "storage_options",
+        "filesystem",
+        "ignore_metadata_file",
+        "kwargs",
+        "_partitions",
+        "_series",
+        "_dataset_info_cache",
+    ]
+    _defaults = {
+        "columns": None,
+        "filters": None,
+        "categories": None,
+        "index": None,
+        "storage_options": None,
+        "filesystem": None,
+        "ignore_metadata_file": True,
+        "kwargs": None,
+        "_partitions": None,
+        "_series": False,
+        "_dataset_info_cache": None,
+    }
+    _pq_length_stats = None
+    _absorb_projections = True
+    _filter_passthrough = True
+
+    @cached_property
+    def normalized_path(self):
+        return pa_fs.FileSystem.from_uri(self.path)[1]
+
+    @cached_property
+    def fs(self):
+        fs_input = self.operand("filesystem")
+        if isinstance(fs_input, pa.fs.FileSystem):
+            return fs_input
+        else:
+            fs = pa_fs.FileSystem.from_uri(self.path)[0]
+            if storage_options := self.storage_options:
+                # Use inferred region as the default
+                region = {} if "region" in storage_options else {"region": fs.region}
+                fs = type(fs)(**region, **storage_options)
+            return fs
+
+    @cached_property
+    def _dataset_info(self):
+        if rv := self.operand("_dataset_info_cache"):
+            return rv
+        dataset_info = {}
+
+        path_normalized = self.normalized_path
+        # At this point we will post a couple of listbucket operations which
+        # include the same data as a HEAD request.
+        # The information included here (see pyarrow FileInfo) is size, type,
+        # path and modified-since timestamps.
+        # This isn't free but relatively cheap (200-300ms or less for ~1k files)
+        finfo = self.fs.get_file_info(path_normalized)
+        if finfo.type == pa.fs.FileType.Directory:
+            dataset_selector = pa_fs.FileSelector(path_normalized, recursive=True)
+            all_files = [
+                finfo
+                for finfo in self.fs.get_file_info(dataset_selector)
+                if finfo.type == pa.fs.FileType.File
+            ]
+        else:
+            all_files = [finfo]
+        # TODO: At this point we could verify if we're dealing with a very
+        # inhomogeneous dataset already without reading any further data
+
+        metadata_file = False
+        checksum = None
+        dataset = None
+        if not self.ignore_metadata_file:
+            all_files = sorted(
+                all_files, key=lambda x: x.base_name.endswith("_metadata")
+            )
+            if all_files[-1].base_name.endswith("_metadata"):
+                metadata_file = all_files.pop()
+                checksum = tokenize(metadata_file)
+                # TODO: dataset kwargs?
+                dataset = pa_ds.parquet_dataset(
+                    metadata_file.path,
+                    filesystem=self.fs,
+                )
+                dataset_info["using_metadata_file"] = True
+                dataset_info["fragments"] = dataset.get_fragments()
+        if checksum is None:
+            checksum = tokenize(all_files)
+        dataset_info["checksum"] = checksum
+        if dataset is None:
+            import pyarrow.parquet as pq
+
+            dataset = pq.ParquetDataset(
+                # TODO: Just pass all_files once
+                # https://github.com/apache/arrow/pull/40143 is available to
+                # reduce latency
+                [fi.path for fi in all_files],
+                filesystem=self.fs,
+                filters=self.filters,
+            )
+            dataset_info["using_metadata_file"] = False
+            dataset_info["fragments"] = dataset.fragments
+
+        dataset_info["dataset"] = dataset
+        dataset_info["schema"] = dataset.schema
+        dataset_info["base_meta"] = dataset.schema.empty_table().to_pandas()
+        self.operands[
+            type(self)._parameters.index("_dataset_info_cache")
+        ] = dataset_info
+        return dataset_info
+
+    def _divisions(self):
+        return tuple([None] * (len(self.fragments) + 1))
+
+    @property
+    def fragments(self):
+        if self.filters is not None:
+            if self._dataset_info["using_metadata_file"]:
+                ds = self._dataset_info["dataset"]
+            else:
+                ds = self._dataset_info["dataset"]._dataset
+            return list(ds.get_fragments(filter=pq.filters_to_expression(self.filters)))
+        return self._dataset_info["fragments"]
+
+    @staticmethod
+    def _fragment_to_pandas(fragment, columns, filters, schema):
+        from dask.utils import parse_bytes
+
+        if isinstance(filters, list):
+            filters = pq.filters_to_expression(filters)
+        # TODO: There should be a way for users to define the type mapper
+        table = fragment.to_table(
+            schema=schema,
+            columns=columns,
+            filter=filters,
+            # Batch size determines how many rows are read at once and will
+            # cause the underlying array to be split into chunks of this size
+            # (max). We'd like to avoid fragmentation as much as possible and
+            # to set this to something like inf, but we have to set a finite,
+            # positive number.
+            # In the presence of row groups, the underlying array will still be
+            # chunked per row group.
+            batch_size=10_000_000,
+            # batch_readahead=16,
+            # fragment_readahead=4,
+            fragment_scan_options=pa.dataset.ParquetFragmentScanOptions(
+                pre_buffer=True,
+                cache_options=pa.CacheOptions(
+                    hole_size_limit=parse_bytes("4 MiB"),
+                    range_size_limit=parse_bytes("32.00 MiB"),
+                ),
+            ),
+            # TODO: Reconsider this. The OMP_NUM_THREADS variable makes it harmful to enable this
+            use_threads=True,
+        )
+        df = table.to_pandas(
+            types_mapper=_determine_type_mapper(),
+            use_threads=False,
+            self_destruct=True,
+        )
+        return df
+
+    def _filtered_task(self, index: int):
+        return (
+            ReadParquetPyarrowFS._fragment_to_pandas,
+            self.fragments[index],
+            self.columns,
+            self.filters,
+            self._dataset_info["schema"],
+        )
+
+
+class ReadParquetFSSpec(ReadParquet):
     """Read a parquet dataset"""
 
     _parameters = [
@@ -450,14 +804,6 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "_series": False,
         "_dataset_info_cache": None,
     }
-    _pq_length_stats = None
-    _absorb_projections = True
-
-    def _tree_repr_argument_construction(self, i, op, header):
-        if self._parameters[i] == "_dataset_info_cache":
-            # Don't print this, very ugly
-            return header
-        return super()._tree_repr_argument_construction(i, op, header)
 
     @property
     def engine(self):
@@ -466,47 +812,8 @@ def engine(self):
             return get_engine(_engine)
         return _engine
 
-    @property
-    def columns(self):
-        columns_operand = self.operand("columns")
-        if columns_operand is None:
-            return list(self._meta.columns)
-        else:
-            return _convert_to_list(columns_operand)
-
-    def _simplify_up(self, parent, dependents):
-        if isinstance(parent, Index):
-            # Column projection
-            columns = determine_column_projection(self, parent, dependents)
-            if set(columns) == set(self.columns):
-                return
-            columns = [col for col in self.columns if col in columns]
-            return self.substitute_parameters({"columns": columns, "_series": False})
-
-        if isinstance(parent, Projection):
-            return super()._simplify_up(parent, dependents)
-
-        if isinstance(parent, Lengths):
-            _lengths = self._get_lengths()
-            if _lengths:
-                return Literal(_lengths)
-
-        if isinstance(parent, Len):
-            _lengths = self._get_lengths()
-            if _lengths:
-                return Literal(sum(_lengths))
-
-    @cached_property
-    def _name(self):
-        return (
-            funcname(type(self)).lower()
-            + "-"
-            + _tokenize_deterministic(self.checksum, *self.operands)
-        )
-
-    @property
-    def checksum(self):
-        return self._dataset_info["checksum"]
+    def _divisions(self):
+        return self._plan["divisions"]
 
     @property
     def _dataset_info(self):
@@ -612,16 +919,11 @@ def _dataset_info(self):
         ] = dataset_info
         return dataset_info
 
-    @property
-    def _meta(self):
-        meta = self._dataset_info["base_meta"]
-        columns = _convert_to_list(self.operand("columns"))
+    def _filtered_task(self, index: int):
+        tsk = (self._io_func, self._plan["parts"][index])
         if self._series:
-            assert len(columns) > 0
-            return meta[columns[0]]
-        elif columns is not None:
-            return meta[columns]
-        return meta
+            return (operator.getitem, tsk, self.columns[0])
+        return tsk
 
     @property
     def _io_func(self):
@@ -678,15 +980,6 @@ def _plan(self):
         }
         return _cached_plan[dataset_token]
 
-    def _divisions(self):
-        return self._plan["divisions"]
-
-    def _filtered_task(self, index: int):
-        tsk = (self._io_func, self._plan["parts"][index])
-        if self._series:
-            return (operator.getitem, tsk, self.columns[0])
-        return tsk
-
     def _get_lengths(self) -> tuple | None:
        """Return known partition lengths using parquet statistics"""
        if not self.filters:
@@ -696,6 +989,7 @@ def _get_lengths(self) -> tuple | None:
             for i, length in enumerate(self._pq_length_stats)
             if not self._filtered or i in self._partitions
         )
+        return None
 
     def _update_length_statistics(self):
         """Ensure that partition-length statistics are up to date"""
@@ -714,15 +1008,6 @@ def _update_length_statistics(self):
                 stat["num-rows"] for stat in _collect_pq_statistics(self)
             )
 
-    @property
-    def _fusion_compression_factor(self):
-        if self.operand("columns") is None:
-            return 1
-        nr_original_columns = len(self._dataset_info["schema"].names) - 1
-        return max(
-            len(_convert_to_list(self.operand("columns"))) / nr_original_columns, 0.001
-        )
-
 
 #
 # Helper functions
@@ -888,19 +1173,15 @@ def combine(self, other: _DNF | _And | _Or | list | tuple | None) -> _DNF:
     def extract_pq_filters(cls, pq_expr: ReadParquet, predicate_expr: Expr) -> _DNF:
         _filters = None
         if isinstance(predicate_expr, (LE, GE, LT, GT, EQ, NE)):
-            if (
-                isinstance(predicate_expr.left, ReadParquet)
-                and predicate_expr.left.path == pq_expr.path
-                and not isinstance(predicate_expr.right, Expr)
+            if are_co_aligned(pq_expr, predicate_expr.left) and not isinstance(
+                predicate_expr.right, Expr
             ):
                 op = predicate_expr._operator_repr
                 column = predicate_expr.left.columns[0]
                 value = predicate_expr.right
                 _filters = (column, op, value)
-            elif (
-                isinstance(predicate_expr.right, ReadParquet)
-                and predicate_expr.right.path == pq_expr.path
-                and not isinstance(predicate_expr.left, Expr)
+            elif are_co_aligned(pq_expr, predicate_expr.right) and not isinstance(
+                predicate_expr.left, Expr
             ):
                 # Simple dict to make sure field comes first in filter
                 flip = {LE: GE, LT: GT, GE: LE, GT: LT}
diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py
index cca5d63c0..a76bf13a0 100644
--- a/dask_expr/io/tests/test_io.py
+++ b/dask_expr/io/tests/test_io.py
@@ -21,8 +21,7 @@
     read_csv,
     read_parquet,
 )
-from dask_expr._expr import Expr, Lengths, Literal, Replace
-from dask_expr._reductions import Len
+from dask_expr._expr import Expr, Replace
 from dask_expr.io import FromArray, FromMap, ReadCSV, ReadParquet, parquet
 from dask_expr.tests._util import _backend_library
 
@@ -122,89 +121,11 @@ def test_read_csv_keywords(tmpdir):
     assert_eq(df, expected)
 
 
-@pytest.mark.skip()
-def test_predicate_pushdown(tmpdir):
-    original = pd.DataFrame(
-        {
-            "a": [1, 2, 3, 4, 5] * 10,
-            "b": [0, 1, 2, 3, 4] * 10,
-            "c": range(50),
-            "d": [6, 7] * 25,
-            "e": [8, 9] * 25,
-        }
-    )
-    fn = _make_file(tmpdir, format="parquet", df=original)
-    df = read_parquet(fn)
-    assert_eq(df, original)
-    x = df[df.a == 5][df.c > 20]["b"]
-    y = optimize(x, fuse=False)
-    assert isinstance(y.expr, ReadParquet)
-    assert ("a", "==", 5) in y.expr.operand("filters")[0]
-    assert ("c", ">", 20) in y.expr.operand("filters")[0]
-    assert list(y.columns) == ["b"]
-
-    # Check computed result
-    y_result = y.compute()
-    assert y_result.name == "b"
-    assert len(y_result) == 6
-    assert (y_result == 4).all()
-
-
-@pytest.mark.skip()
-def test_predicate_pushdown_compound(tmpdir):
-    pdf = pd.DataFrame(
-        {
-            "a": [1, 2, 3, 4, 5] * 10,
-            "b": [0, 1, 2, 3, 4] * 10,
-            "c": range(50),
-            "d": [6, 7] * 25,
-            "e": [8, 9] * 25,
-        }
-    )
-    fn = _make_file(tmpdir, format="parquet", df=pdf)
-    df = read_parquet(fn)
-
-    # Test AND
-    x = df[(df.a == 5) & (df.c > 20)]["b"]
-    y = optimize(x, fuse=False)
-    assert isinstance(y.expr, ReadParquet)
-    assert {("c", ">", 20), ("a", "==", 5)} == set(y.filters[0])
-    assert_eq(
-        y,
-        pdf[(pdf.a == 5) & (pdf.c > 20)]["b"],
-        check_index=False,
-    )
-
-    # Test OR
-    x = df[(df.a == 5) | (df.c > 20)]
-    x = x[x.b != 0]["b"]
-    y = optimize(x, fuse=False)
-    assert isinstance(y.expr, ReadParquet)
-    filters = [set(y.filters[0]), set(y.filters[1])]
-    assert {("c", ">", 20), ("b", "!=", 0)} in filters
-    assert {("a", "==", 5), ("b", "!=", 0)} in filters
-    expect = pdf[(pdf.a == 5) | (pdf.c > 20)]
-    expect = expect[expect.b != 0]["b"]
-    assert_eq(
-        y,
-        expect,
-        check_index=False,
-    )
-
-    # Test OR and AND
-    x = df[((df.a == 5) | (df.c > 20)) & (df.b != 0)]["b"]
-    z = optimize(x, fuse=False)
-    assert isinstance(z.expr, ReadParquet)
-    filters = [set(z.filters[0]), set(z.filters[1])]
-    assert {("c", ">", 20), ("b", "!=", 0)} in filters
-    assert {("a", "==", 5), ("b", "!=", 0)} in filters
-    assert_eq(y, z)
-
-
 def test_io_fusion_blockwise(tmpdir):
     pdf = pd.DataFrame({c: range(10) for c in "abcdefghijklmn"})
     dd.from_pandas(pdf, 3).to_parquet(tmpdir)
-    df = read_parquet(tmpdir)["a"].fillna(10).optimize()
+    read_parq = read_parquet(tmpdir)
+    df = read_parq["a"].fillna(10).optimize()
     assert df.npartitions == 2
     assert len(df.__dask_graph__()) == 2
     graph = (
@@ -213,7 +134,9 @@ def test_io_fusion_blockwise(tmpdir):
         .optimize(fuse=False)
         .__dask_graph__()
     )
-    assert any("readparquet-fused" in key[0] for key in graph.keys())
+    assert any(
+        f"{read_parq._expr._name.split('-')[0]}-fused" in key[0] for key in graph.keys()
+    )
 
 
 def test_repartition_io_fusion_blockwise(tmpdir):
@@ -289,27 +212,6 @@ def test_parquet_complex_filters(tmpdir):
     assert_eq(got.optimize(), expect)
 
 
-def test_parquet_len(tmpdir):
-    df = read_parquet(_make_file(tmpdir))
-    pdf = df.compute()
-
-    assert len(df[df.a > 5]) == len(pdf[pdf.a > 5])
-
-    s = (df["b"] + 1).astype("Int32")
-    assert len(s) == len(pdf)
-
-    assert isinstance(Len(s.expr).optimize(), Literal)
-    assert isinstance(Lengths(s.expr).optimize(), Literal)
-
-
-def test_parquet_len_filter(tmpdir):
-    df = read_parquet(_make_file(tmpdir))
-    expr = Len(df[df.c > 0].expr)
-    result = expr.simplify()
-    for rp in result.find_operations(ReadParquet):
-        assert rp.operand("columns") == ["c"] or rp.operand("columns") == []
-
-
 @pytest.mark.parametrize("optimize", [True, False])
 def test_from_dask_dataframe(optimize):
     ddf = dd.from_dict({"a": range(100)}, npartitions=10)
@@ -336,35 +238,6 @@ def test_to_dask_array(optimize):
     array_assert_eq(darr, pdf.values)
 
 
-@pytest.mark.parametrize("write_metadata_file", [True, False])
-def test_to_parquet(tmpdir, write_metadata_file):
-    pdf = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]})
-    df = from_pandas(pdf, npartitions=2)
-
-    # Check basic parquet round trip
-    df.to_parquet(tmpdir, write_metadata_file=write_metadata_file)
-    df2 = read_parquet(tmpdir, calculate_divisions=True)
-    assert_eq(df, df2)
-
-    # Check overwrite behavior
-    df["new"] = df["x"] + 1
-    df.to_parquet(tmpdir, overwrite=True, write_metadata_file=write_metadata_file)
-    df2 = read_parquet(tmpdir, calculate_divisions=True)
-    assert_eq(df, df2)
-
-    # Check that we cannot overwrite a path we are
-    # reading from in the same graph
-    with pytest.raises(ValueError, match="Cannot overwrite"):
-        df2.to_parquet(tmpdir, overwrite=True)
-
-
-def test_to_parquet_engine(tmpdir):
-    pdf = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]})
-    df = from_pandas(pdf, npartitions=2)
-    with pytest.raises(NotImplementedError, match="not supported"):
-        df.to_parquet(tmpdir + "engine.parquet", engine="fastparquet")
-
-
 @pytest.mark.parametrize(
     "fmt,read_func,read_cls",
     [("parquet", read_parquet, ReadParquet), ("csv", read_csv, ReadCSV)],
diff --git a/dask_expr/io/tests/test_parquet.py b/dask_expr/io/tests/test_parquet.py
new file mode 100644
index 000000000..791a79dd4
--- /dev/null
+++ b/dask_expr/io/tests/test_parquet.py
@@ -0,0 +1,123 @@
+import os
+
+import pandas as pd
+import pytest
+from dask.dataframe.utils import assert_eq
+from pyarrow import fs
+
+from dask_expr import from_pandas, read_parquet
+from dask_expr._expr import Lengths, Literal
+from dask_expr._reductions import Len
+from dask_expr.io import ReadParquet
+
+
+def _make_file(dir, df=None):
+    fn = os.path.join(str(dir), "myfile.parquet")
+    if df is None:
+        df = pd.DataFrame({c: range(10) for c in "abcde"})
+    df.to_parquet(fn)
+    return fn
+
+
+@pytest.fixture
+def parquet_file(tmpdir):
+    return _make_file(tmpdir)
+
+
+def test_parquet_len(tmpdir):
+    df = read_parquet(_make_file(tmpdir))
+    pdf = df.compute()
+
+    assert len(df[df.a > 5]) == len(pdf[pdf.a > 5])
+
+    s = (df["b"] + 1).astype("Int32")
+    assert len(s) == len(pdf)
+
+    assert isinstance(Len(s.expr).optimize(), Literal)
+    assert isinstance(Lengths(s.expr).optimize(), Literal)
+
+
+def test_parquet_len_filter(tmpdir):
+    df = read_parquet(_make_file(tmpdir))
+    expr = Len(df[df.c > 0].expr)
+    result = expr.simplify()
+    for rp in result.find_operations(ReadParquet):
+        assert rp.operand("columns") == ["c"] or rp.operand("columns") == []
+
+
+@pytest.mark.parametrize("write_metadata_file", [True, False])
+def test_to_parquet(tmpdir, write_metadata_file):
+    pdf = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]})
+    df = from_pandas(pdf, npartitions=2)
+
+    # Check basic parquet round trip
+    df.to_parquet(tmpdir, write_metadata_file=write_metadata_file)
+    df2 = read_parquet(tmpdir, calculate_divisions=True)
+    assert_eq(df, df2)
+
+    # Check overwrite behavior
+    df["new"] = df["x"] + 1
+    df.to_parquet(tmpdir, overwrite=True, write_metadata_file=write_metadata_file)
+    df2 = read_parquet(tmpdir, calculate_divisions=True)
+    assert_eq(df, df2)
+
+    # Check that we cannot overwrite a path we are
+    # reading from in the same graph
+    with pytest.raises(ValueError, match="Cannot overwrite"):
+        df2.to_parquet(tmpdir, overwrite=True)
+
+
+def test_to_parquet_engine(tmpdir):
+    pdf = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]})
+    df = from_pandas(pdf, npartitions=2)
+    with pytest.raises(NotImplementedError, match="not supported"):
+        df.to_parquet(tmpdir + "engine.parquet", engine="fastparquet")
+
+
+def test_pyarrow_filesystem(parquet_file):
+    filesystem = fs.LocalFileSystem()
+
+    df_pa = read_parquet(parquet_file, filesystem=filesystem)
+    df = read_parquet(parquet_file)
+    assert assert_eq(df, df_pa)
+
+
+def test_pyarrow_filesystem_filters(parquet_file):
+    filesystem = fs.LocalFileSystem()
+
+    df_pa = read_parquet(parquet_file, filesystem=filesystem)
+    df_pa = df_pa[df_pa.c == 1]
+    expected = read_parquet(
+        parquet_file, filesystem=filesystem, filters=[[("c", "==", 1)]]
+    )
+    assert df_pa.optimize()._name == expected.optimize()._name
+    assert len(df_pa.compute()) == 1
+
+
+def test_partition_pruning(tmpdir):
+    filesystem = fs.LocalFileSystem()
+    df = from_pandas(
+        pd.DataFrame(
+            {
+                "a": [1, 2, 3, 4, 5] * 10,
+                "b": range(50),
+            }
+        ),
+        npartitions=2,
+    )
+    df.to_parquet(tmpdir, partition_on=["a"])
+    ddf = read_parquet(tmpdir, filesystem=filesystem)
+    ddf_filtered = read_parquet(
+        tmpdir, filters=[[("a", "==", 1)]], filesystem=filesystem
+    )
+    assert ddf_filtered.npartitions == ddf.npartitions // 5
+
+    ddf_optimize = read_parquet(tmpdir, filesystem=filesystem)
+    ddf_optimize = ddf_optimize[ddf_optimize.a == 1].optimize()
+    assert ddf_optimize.npartitions == ddf.npartitions // 5
+    assert_eq(
+        ddf_filtered,
+        ddf_optimize,
+        # FIXME ?
+        check_names=False,
+    )