diff --git a/dask_expr/_align.py b/dask_expr/_align.py index ea28b8fed..a1a34700b 100644 --- a/dask_expr/_align.py +++ b/dask_expr/_align.py @@ -2,7 +2,7 @@ from tlz import merge_sorted, unique -from dask_expr._expr import Expr, Projection, is_broadcastable +from dask_expr._expr import Expr, Projection, is_broadcastable, plain_column_projection from dask_expr._repartition import RepartitionDivisions @@ -25,9 +25,9 @@ def _divisions(self): divisions = (divisions[0], divisions[0]) return divisions - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) def _lower(self): if not self.dfs: diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index 5039953ba..c9acbd885 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -251,15 +251,15 @@ def __bool__(self): "Use a.any() or a.all()." ) - def persist(self, fuse=True, combine_similar=True, **kwargs): - out = self.optimize(combine_similar=combine_similar, fuse=fuse) + def persist(self, fuse=True, **kwargs): + out = self.optimize(fuse=fuse) return DaskMethodsMixin.persist(out, **kwargs) - def compute(self, fuse=True, combine_similar=True, **kwargs): + def compute(self, fuse=True, **kwargs): out = self if not isinstance(out, Scalar): out = out.repartition(npartitions=1) - out = out.optimize(combine_similar=combine_similar, fuse=fuse) + out = out.optimize(fuse=fuse) return DaskMethodsMixin.compute(out, **kwargs) def __dask_graph__(self): @@ -278,10 +278,8 @@ def simplify(self): def lower_once(self): return new_collection(self.expr.lower_once()) - def optimize(self, combine_similar: bool = True, fuse: bool = True): - return new_collection( - self.expr.optimize(combine_similar=combine_similar, fuse=fuse) - ) + def optimize(self, fuse: bool = True): + return new_collection(self.expr.optimize(fuse=fuse)) @property def dask(self): diff --git a/dask_expr/_concat.py b/dask_expr/_concat.py index d0ae12f94..26b884fd5 100644 --- a/dask_expr/_concat.py +++ b/dask_expr/_concat.py @@ -8,7 +8,14 @@ from dask.dataframe.utils import check_meta, strip_unknown_categories from dask.utils import apply, is_dataframe_like, is_series_like -from dask_expr._expr import AsType, Blockwise, Expr, Projection, are_co_aligned +from dask_expr._expr import ( + AsType, + Blockwise, + Expr, + Projection, + are_co_aligned, + determine_column_projection, +) class Concat(Expr): @@ -139,13 +146,13 @@ def _lower(self): *cast_dfs, ) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): def get_columns_or_name(e: Expr): return e.columns if e.ndim == 2 else [e.name] - columns = parent.columns + columns = determine_column_projection(self, parent, dependents) columns_frame = [ [col for col in get_columns_or_name(frame) if col in columns] for frame in self._frames diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 08d21e88d..802ec57e0 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -2,6 +2,8 @@ import functools import os +import weakref +from collections import defaultdict from collections.abc import Generator import dask @@ -10,7 +12,7 @@ from dask.dataframe.core import is_dataframe_like, is_index_like, is_series_like from dask.utils import funcname, import_required, is_arraylike -from dask_expr._util import _BackendData, _tokenize_deterministic, _tokenize_partial +from dask_expr._util import _BackendData, _tokenize_deterministic class Expr: @@ -245,25 +247,79 @@ def rewrite(self, kind: str): return expr - def simplify(self): + def simplify_once(self, dependents: defaultdict): """Simplify an expression This leverages the ``._simplify_down`` and ``._simplify_up`` methods defined on each class + Parameters + ---------- + + dependents: defaultdict[list] + The dependents for every node. + Returns ------- expr: output expression - changed: - whether or not any change occured """ - return self.rewrite(kind="simplify") + expr = self + + while True: + out = expr._simplify_down() + if out is None: + out = expr + if not isinstance(out, Expr): + return out + if out._name != expr._name: + expr = out + + # Allow children to simplify their parents + for child in expr.dependencies(): + out = child._simplify_up(expr, dependents) + if out is None: + out = expr + + if not isinstance(out, Expr): + return out + if out is not expr and out._name != expr._name: + expr = out + break + + # Rewrite all of the children + new_operands = [] + changed = False + for operand in expr.operands: + if isinstance(operand, Expr): + new = operand.simplify_once(dependents=dependents) + if new._name != operand._name: + changed = True + else: + new = operand + new_operands.append(new) + + if changed: + expr = type(expr)(*new_operands) + + break + + return expr + + def simplify(self) -> Expr: + expr = self + while True: + dependents = collect_depdendents(expr) + new = expr.simplify_once(dependents=dependents) + if new._name == expr._name: + break + expr = new + return expr def _simplify_down(self): return - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): return def lower_once(self): @@ -322,42 +378,6 @@ def lower_completely(self) -> Expr: def _lower(self): return - def _remove_operations(self, frame, remove_ops, skip_ops=None): - """Searches for operations that we have to push up again to avoid - the duplication of branches that are doing the same. - - Parameters - ---------- - frame: Expression that we will search. - remove_ops: Ops that we will remove to push up again. - skip_ops: Ops that were introduced and that we want to ignore. - - Returns - ------- - tuple of the new expression and the operations that we removed. - """ - - operations, ops_to_push_up = [], [] - frame_base = frame - combined_ops = remove_ops if skip_ops is None else remove_ops + skip_ops - while isinstance(frame, combined_ops): - # Have to respect ops that were injected while lowering or filters - if isinstance(frame, remove_ops): - ops_to_push_up.append(frame.operands[1]) - frame = frame.frame - break - else: - operations.append((type(frame), frame.operands[1:])) - frame = frame.frame - - if len(ops_to_push_up) > 0: - # Remove the projections but build the remaining things back up - for op_type, operands in reversed(operations): - frame = op_type(frame, *operands) - return frame, ops_to_push_up - else: - return frame_base, [] - @functools.cached_property def _name(self): return ( @@ -509,21 +529,6 @@ def substitute_parameters(self, substitutions: dict) -> Expr: return type(self)(*new_operands) return self - def _find_similar_operations(self, root: Expr, ignore: list | None = None): - # Find operations with the same type and operands. - # Parameter keys specified by `ignore` will not be - # included in the operand comparison - alike = [ - op for op in root.find_operations(type(self)) if op._name != self._name - ] - if not alike: - # No other operations of the same type. Early return - return [] - - # Return subset of `alike` with the same "token" - token = _tokenize_partial(self, ignore) - return [item for item in alike if _tokenize_partial(item, ignore) == token] - def _node_label_args(self): """Operands to include in the node label by `visualize`""" return self.dependencies() @@ -668,3 +673,19 @@ def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]: or issubclass(operation, Expr) ), "`operation` must be`Expr` subclass)" return (expr for expr in self.walk() if isinstance(expr, operation)) + + +def collect_depdendents(expr) -> defaultdict: + dependents = defaultdict(list) + stack = [expr] + seen = set() + while stack: + node = stack.pop() + if node._name in seen: + continue + seen.add(node._name) + + for dep in node.dependencies(): + stack.append(dep) + dependents[dep._name].append(weakref.ref(node)) + return dependents diff --git a/dask_expr/_cumulative.py b/dask_expr/_cumulative.py index d9df34091..73ba1fdf5 100644 --- a/dask_expr/_cumulative.py +++ b/dask_expr/_cumulative.py @@ -3,7 +3,7 @@ from dask.dataframe import methods from dask.utils import M -from dask_expr._expr import Blockwise, Expr, Projection +from dask_expr._expr import Blockwise, Expr, Projection, plain_column_projection class CumulativeAggregations(Expr): @@ -27,9 +27,9 @@ def _lower(self): chunks_last = TakeLast(chunks, self.skipna) return CumulativeFinalize(chunks, chunks_last, self.aggregate_operation) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) class CumulativeBlockwise(Blockwise): diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 8671c05ef..6b655b3d3 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -32,6 +32,7 @@ from dask_expr import _core as core from dask_expr._util import ( _calc_maybe_new_divisions, + _convert_to_list, _tokenize_deterministic, _tokenize_partial, is_scalar, @@ -93,122 +94,6 @@ def __getattr__(self, key): f"API function. Current API coverage is documented here: {link}." ) - def combine_similar( - self, root: Expr | None = None, _cache: dict | None = None - ) -> Expr: - """Combine similar expression nodes using global information - - This leverages the ``._combine_similar`` method defined - on each class. The global expression-tree traversal will - change IO leaves first, and finish with the root expression. - The primary purpose of this method is to allow column - projections to be "pushed back up" the expression graph - in the case that simlar IO & Blockwise operations can - be captured by the same operations. - - Parameters - ---------- - root: - The root node of the global expression graph. If not - specified, the root is assumed to be ``self``. - _cache: - Optional dictionary to use for caching. - - Returns - ------- - expr: - output expression - """ - expr = self - update_root = root is None - root = root if root is not None else self - - if _cache is None: - _cache = {} - elif (self._name, root._name) in _cache: - return _cache[(self._name, root._name)] - - seen = set() - while expr._name not in seen: - # Call combine_similar on each dependency - new_operands = [] - changed_dependency = False - for operand in expr.operands: - if isinstance(operand, Expr): - new = operand.combine_similar(root=root, _cache=_cache) - if new._name != operand._name: - changed_dependency = True - else: - new = operand - new_operands.append(new) - - if changed_dependency: - expr = type(expr)(*new_operands) - if isinstance(expr, Projection): - # We might introduce stacked Projections (merge for example). - # So get rid of them here again - expr_simplify_down = expr._simplify_down() - if expr_simplify_down is not None: - expr = expr_simplify_down - if update_root: - root = expr - continue - - # Execute "_combine_similar" on expr - out = expr._combine_similar(root) - if out is None: - out = expr - if not isinstance(out, Expr): - _cache[(self._name, root._name)] = out - return out - seen.add(expr._name) - if expr._name != out._name and update_root: - root = expr - expr = out - - _cache[(self._name, root._name)] = expr - return expr - - def _combine_similar(self, root: Expr): - return - - def _combine_similar_branches(self, root, remove_ops, skip_ops=None): - # We have to go back until we reach an operation that was not pushed down - frame, operations = self._remove_operations(self.frame, remove_ops, skip_ops) - try: - common = type(self)(frame, *self.operands[1:]) - except ValueError: - # May have encountered a problem with `_required_attribute`. - # (There is no guarentee that the same method will exist for - # both a Series and DataFrame) - return None - - others = self._find_similar_operations(root, ignore=self._parameters) - - others_compatible = [] - for op in others: - if ( - isinstance(op.frame, remove_ops) - and (common._name == type(op)(op.frame.frame, *op.operands[1:])._name) - ) or common._name == op._name: - others_compatible.append(op) - - if isinstance(self.frame, Filter) and all( - isinstance(op.frame, Filter) for op in others_compatible - ): - # Avoid pushing filters up if all similar ops - # are acting on a Filter-based expression anyway - return None - - if len(others_compatible) > 0: - # Add operations back in the same order - for i, op in enumerate(reversed(operations)): - common = common[op] - if i > 0: - # Combine stacked projections - common = common._simplify_down() or common - return common - @property def index(self): return Index(self) @@ -488,6 +373,12 @@ def columns(self) -> list: if self.ndim == 1: return [self.name] return [] + except Exception: + raise + + @property + def _projection_columns(self): + return self.columns @property def name(self): @@ -615,21 +506,9 @@ def _task(self, index: int): else: return (self.operation,) + tuple(args) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if self._projection_passthrough and isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) - - def _combine_similar(self, root: Expr): - # Push projections back up through `_projection_passthrough` - # operations if it reduces the number of unique expression nodes. - if ( - self._projection_passthrough - and isinstance(self.frame, Projection) - or self._filter_passthrough - and isinstance(self.frame, Filter) - ): - return self._combine_similar_branches(root, (Filter, Projection)) - return None + return plain_column_projection(self, parent, dependents) class MapPartitions(Blockwise): @@ -942,14 +821,17 @@ class DropnaFrame(Blockwise): _keyword_only = ["how", "subset", "thresh"] operation = M.dropna - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if self.subset is not None: - columns = set(parent.columns).union(self.subset) - if columns == set(self.frame.columns): + columns = determine_column_projection( + self, parent, dependents, additional_columns=self.subset + ) + columns = [col for col in self.frame.columns if col in columns] + + if columns == self.frame.columns: # Don't add unnecessary Projections return - columns = [col for col in self.frame.columns if col in columns] return type(parent)( type(self)(self.frame[columns], *self.operands[1:]), *parent.operands[1:], @@ -969,19 +851,17 @@ def _meta(self): ), ) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - columns = parent.columns - frame_columns = sorted(set(columns).intersection(self.frame.columns)) - other_columns = sorted(set(columns).intersection(self.other.columns)) + columns = determine_column_projection(self, parent, dependents) + frame_columns = [col for col in self.frame.columns if col in columns] + other_columns = [col for col in self.other.columns if col in columns] if ( self.frame.columns == frame_columns and self.other.columns == other_columns ): return - frame_columns = [col for col in self.frame.columns if col in columns] - other_columns = [col for col in self.other.columns if col in columns] return type(parent)( type(self)(self.frame[frame_columns], self.other[other_columns]), *parent.operands[1:], @@ -1040,12 +920,15 @@ class Elemwise(Blockwise): _filter_passthrough = True _is_length_preserving = True - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if self._filter_passthrough and isinstance(parent, Filter): + parents = [x() for x in dependents[self._name] if x() is not None] + if not all(isinstance(p, Filter) for p in parents): + return return type(self)( self.frame[parent.operand("predicate")], *self.operands[1:] ) - return super()._simplify_up(parent) + return super()._simplify_up(parent, dependents) class RenameFrame(Elemwise): @@ -1053,20 +936,27 @@ class RenameFrame(Elemwise): _keyword_only = ["columns"] operation = M.rename - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection) and isinstance( self.operand("columns"), Mapping ): reverse_mapping = {val: key for key, val in self.operand("columns").items()} if is_series_like(parent._meta): - # Fill this out when Series.rename is implemented return - else: - columns = [ - reverse_mapping[col] if col in reverse_mapping else col - for col in parent.columns - ] - return type(self)(self.frame[columns], *self.operands[1:]) + + columns = determine_column_projection(self, parent, dependents) + columns = [ + reverse_mapping[col] if col in reverse_mapping else col + for col in columns + ] + columns = [col for col in self.frame.columns if col in columns] + if columns == self.frame.columns: + return + + return type(parent)( + type(self)(self.frame[columns], *self.operands[1:]), + *parent.operands[1:], + ) class RenameSeries(Elemwise): @@ -1131,12 +1021,9 @@ class Clip(Elemwise): _defaults = {"lower": None, "upper": None} operation = M.clip - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - if self.frame.columns == parent.columns: - # Don't introduce unnecessary projections - return - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) class Between(Elemwise): @@ -1208,16 +1095,22 @@ def _cat_dtype_without_categories(dtype): meta = clear_known_categories(meta) return meta - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): dtypes = self.operand("dtypes") + columns = determine_column_projection(self, parent, dependents) if isinstance(dtypes, dict): - dtypes = { - key: val for key, val in dtypes.items() if key in parent.columns - } + dtypes = {key: val for key, val in dtypes.items() if key in columns} if not dtypes: return type(parent)(self.frame, *parent.operands[1:]) - return type(self)(self.frame[parent.operand("columns")], dtypes) + if isinstance(columns, list): + columns = [col for col in self.frame.columns if col in columns] + if self.frame.columns == columns: + return + result = type(self)(self.frame[columns], dtypes) + if not isinstance(columns, list): + return result + return type(parent)(result, *parent.operands[1:]) class IsNa(Elemwise): @@ -1384,18 +1277,9 @@ class ExplodeSeries(Blockwise): class ExplodeFrame(ExplodeSeries): _parameters = ["frame", "column"] - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - columns = set(parent.columns).union(self.column) - if columns == set(self.frame.columns): - # Don't add unnecessary Projections, protects against loops - return - - columns = [col for col in self.frame.columns if col in columns] - return type(parent)( - type(self)(self.frame[columns], *self.operands[1:]), - *parent.operands[1:], - ) + return plain_column_projection(self, parent, dependents, [self.column]) class Drop(Elemwise): @@ -1426,12 +1310,13 @@ def _meta(self): def _node_label_args(self): return [self.frame, self.key, self.value] - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - if self.key not in parent.columns: + columns = determine_column_projection(self, parent, dependents) + if self.key not in columns: return type(parent)(self.frame, *parent.operands[1:]) - columns = set(parent.columns) - {self.key} + columns = set(columns) - {self.key} if columns == set(self.frame.columns): # Protect against pushing the same projection twice return @@ -1459,9 +1344,9 @@ class Filter(Blockwise): _parameters = ["frame", "predicate"] operation = operator.getitem - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return self.frame[parent.operand("columns")][self.predicate] + return plain_column_projection(self, parent, dependents) if isinstance(parent, Index): return self.frame.index[self.predicate] @@ -1547,6 +1432,10 @@ def _meta(self): return self.frame._meta.index return meta + @property + def _projection_columns(self): + return [] + def _task(self, index: int): return ( getattr, @@ -1625,11 +1514,14 @@ def _convert_columns(self, columns): len_prefix = len(self.prefix) return [col[len_prefix:] for col in columns] - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - columns = self._convert_columns(parent.columns) - if columns == self.frame.columns: + columns = determine_column_projection(self, parent, dependents) + columns = self._convert_columns(_convert_to_list(columns)) + if set(columns) == set(self.frame.columns): return + + columns = [col for col in self.frame.columns if col in columns] return type(parent)( type(self)(self.frame[columns], self.operands[1]), parent.operand("columns"), @@ -1671,7 +1563,7 @@ def _simplify_down(self): if isinstance(self.frame, Head): return Head(self.frame.frame, min(self.n, self.frame.n)) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): from dask_expr import Repartition if isinstance(parent, Repartition) and parent.new_partitions == 1: @@ -1730,7 +1622,7 @@ def _simplify_down(self): if isinstance(self.frame, Tail): return Tail(self.frame.frame, min(self.n, self.frame.n)) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): from dask_expr import Repartition if isinstance(parent, Repartition) and parent.new_partitions == 1: @@ -1774,19 +1666,34 @@ class Binop(Elemwise): def __str__(self): return f"{self.left} {self._operator_repr} {self.right}" - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - if isinstance(self.left, Expr) and self.left.ndim: - left = self.left[ - parent.operand("columns") - ] # TODO: filter just the correct columns + changed = False + columns = determine_column_projection(self, parent, dependents) + columns = _convert_to_list(columns) + columns = [col for col in self.columns if col in columns] + if ( + isinstance(self.left, Expr) + and self.left.ndim > 1 + and self.left.columns != columns + ): + left = self.left[columns] # TODO: filter just the correct columns + changed = True else: left = self.left - if isinstance(self.right, Expr) and self.right.ndim: - right = self.right[parent.operand("columns")] + if ( + isinstance(self.right, Expr) + and self.right.ndim > 1 + and self.right.columns != columns + ): + right = self.right[columns] # TODO: filter just the correct columns + changed = True else: right = self.right - return type(self)(left, right) + if not changed: + return + + return type(parent)(type(self)(left, right), *parent.operands[1:]) def _node_label_args(self): return [self.left, self.right] @@ -1898,12 +1805,10 @@ class Unaryop(Elemwise): def __str__(self): return f"{self._operator_repr} {self.frame}" - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): if isinstance(self.frame, Expr): - frame = self.frame[ - parent.operand("columns") - ] # TODO: filter just the correct columns + return plain_column_projection(self, parent, dependents) else: frame = self.frame return type(self)(frame) @@ -2033,39 +1938,30 @@ def normalize_expression(expr): return expr._name -def optimize(expr: Expr, combine_similar: bool = True, fuse: bool = True) -> Expr: +def optimize(expr: Expr, fuse: bool = True) -> Expr: """High level query optimization This leverages three optimization passes: 1. Class based simplification using the ``_simplify`` function and methods - 2. Combine similar operations - 3. Blockwise fusion + 2. Blockwise fusion Parameters ---------- expr: Input expression to optimize - combine_similar: - whether or not to combine similar operations - (like `ReadParquet`) to aggregate redundant work. fuse: whether or not to turn on blockwise fusion See Also -------- simplify - combine_similar optimize_blockwise_fusion """ # Simplify result = expr.simplify() - # Combine similar - if combine_similar: - result = result.combine_similar() - # Manipulate Expression to make it more efficient result = result.rewrite(kind="tune") @@ -2247,9 +2143,9 @@ def _divisions(self): def _meta(self): return meta_nonempty(self.frame._meta).diff(**self.kwargs) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) @functools.cached_property def kwargs(self): @@ -2279,9 +2175,9 @@ def _divisions(self): def _meta(self): return self.frame._meta - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) @functools.cached_property def kwargs(self): @@ -2342,9 +2238,9 @@ def _meta(self): def kwargs(self): return dict(periods=self.periods, freq=self.freq) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) @property def before(self): @@ -2488,6 +2384,56 @@ def _execute_task(graph, name, *deps): return dask.core.get(graph, name) +# Used for sorting with None +@functools.total_ordering +class MinType: + def __le__(self, other): + return True + + +def determine_column_projection(expr, parent, dependents, additional_columns=None): + if isinstance(parent, Index): + column_union = [] + else: + column_union = parent.columns.copy() + parents = [x() for x in dependents[expr._name] if x() is not None] + + for p in parents: + if len(p.columns) > 0: + column_union.append(p._projection_columns) + + if additional_columns is not None: + column_union.append(additional_columns) + + # We can end up with MultiIndex columns from groupby ops, needs to be + # accounted for in the sort + column_union = sorted( + set(flatten(column_union, container=list)), + key=lambda x: x[0] if isinstance(x, tuple) else x or MinType(), + ) + if ( + len(column_union) == 1 + and parent.ndim == 1 + and all(p.ndim == 1 for p in parents) + ): + return column_union[0] + return column_union + + +def plain_column_projection(expr, parent, dependents, additional_columns=None): + column_union = determine_column_projection( + expr, parent, dependents, additional_columns=additional_columns + ) + if column_union == expr.frame.columns: + return + if isinstance(column_union, list): + column_union = [col for col in expr.frame.columns if col in column_union] + result = type(expr)(expr.frame[column_union], *expr.operands[1:]) + if column_union == parent.operand("columns"): + return result + return type(parent)(result, parent.operand("columns")) + + from dask_expr._reductions import ( All, Any, diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index 7bc9ea3f3..4b6970cde 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -50,6 +50,7 @@ RenameFrame, RenameSeries, are_co_aligned, + determine_column_projection, no_default, ) from dask_expr._reductions import ApplyConcatApply, Chunk, Reduction @@ -134,6 +135,10 @@ def split_out(self): return 1 return super().split_out + @property + def _projection_columns(self): + return self.frame.columns + def _tune_down(self): if len(self.by) > 1 and self.operand("split_out") is None: return self.substitute_parameters( @@ -230,16 +235,8 @@ def aggregate_kwargs(self) -> dict: **aggregate_kwargs, } - def _simplify_up(self, parent): - if isinstance(parent, Projection): - columns = sorted(set(parent.columns + self._by_columns)) - if columns == self.frame.columns: - return - columns = [col for col in self.frame.columns if col in columns] - return type(parent)( - type(self)(self.frame[columns], *self.operands[1:]), - *parent.operands[1:], - ) + def _simplify_up(self, parent, dependents): + return groupby_projection(self, parent, dependents) class GroupbyAggregation(GroupByApplyConcatApply, GroupByBase): @@ -528,16 +525,8 @@ def _divisions(self): return (None, None) return (None,) * (self.split_out + 1) - def _simplify_up(self, parent): - if isinstance(parent, Projection): - columns = sorted(set(parent.columns + self._by_columns)) - if columns == self.frame.columns: - return - columns = [col for col in self.frame.columns if col in columns] - return type(parent)( - type(self)(self.frame[columns], *self.operands[1:]), - *parent.operands[1:], - ) + def _simplify_up(self, parent, dependents): + return groupby_projection(self, parent, dependents) class Std(SingleAggregation): @@ -763,17 +752,9 @@ def _fillna(group, *, what, **kwargs): class GroupByBFill(GroupByTransform): func = staticmethod(functools.partial(_fillna, what="bfill")) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - by_columns = self._by_columns - columns = sorted(set(parent.columns + by_columns)) - if columns == self.frame.columns: - return - columns = [col for col in self.frame.columns if col in columns] - return type(parent)( - type(self)(self.frame[columns], *self.operands[1:]), - *parent.operands[1:], - ) + return groupby_projection(self, parent, dependents) class GroupByFFill(GroupByBFill): @@ -805,16 +786,9 @@ def grp_func(self): def _shuffle_grp_func(self, shuffled=False): return self.grp_func - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - by_columns = self._by_columns - columns = sorted(set(parent.columns + by_columns)) - if columns == self.frame.columns: - return - return type(parent)( - type(self)(self.frame[columns], *self.operands[1:]), - *parent.operands[1:], - ) + return groupby_projection(self, parent, dependents) class GetGroup(Blockwise, GroupByBase): @@ -977,6 +951,21 @@ def _extract_meta(x, nonempty=False): return x +def groupby_projection(expr, parent, dependents): + if isinstance(parent, Projection): + columns = determine_column_projection( + expr, parent, dependents, additional_columns=expr._by_columns + ) + if columns == expr.frame.columns: + return + columns = [col for col in expr.frame.columns if col in columns] + return type(parent)( + type(expr)(expr.frame[columns], *expr.operands[1:]), + *parent.operands[1:], + ) + return + + ### ### Groupby Collection API ### diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index 7b8249ded..3982db1af 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -2,7 +2,6 @@ import math import operator -from dask.core import flatten from dask.dataframe.dispatch import make_meta, meta_nonempty from dask.dataframe.multi import _concat_wrapper, _merge_chunk_wrapper, _split_partition from dask.dataframe.shuffle import partitioning_index @@ -12,19 +11,14 @@ from dask_expr._expr import ( Blockwise, Expr, - Filter, Index, PartitionsFiltered, Projection, + determine_column_projection, ) from dask_expr._repartition import Repartition -from dask_expr._shuffle import ( - AssignPartitioningIndex, - Shuffle, - _contains_index_name, - _select_columns_or_index, -) -from dask_expr._util import _convert_to_list, _tokenize_deterministic +from dask_expr._shuffle import Shuffle, _contains_index_name, _select_columns_or_index +from dask_expr._util import _convert_to_list, _tokenize_deterministic, is_scalar _HASH_COLUMN_NAME = "__hash_partition" _PARTITION_COLUMN = "_partitions" @@ -70,10 +64,6 @@ class Merge(Expr): "broadcast": None, } - # combine similar variables - _skip_ops = (Filter, AssignPartitioningIndex, Shuffle) - _remove_ops = (Projection,) - def __str__(self): return f"Merge({self._name[-7:]})" @@ -333,19 +323,19 @@ def _lower(self): # Blockwise merge return BlockwiseMerge(left, right, **self.kwargs) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, (Projection, Index)): # Reorder the column projection to # occur before the Merge + columns = determine_column_projection(self, parent, dependents) + columns = _convert_to_list(columns) if isinstance(parent, Index): # Index creates an empty column projection - projection, parent_columns = [], None + projection, parent_columns = columns, None else: - projection, parent_columns = parent.operand("columns"), parent.operand( - "columns" - ) - if isinstance(projection, (str, int)): - projection = [projection] + projection, parent_columns = columns, parent.operand("columns") + if is_scalar(projection): + projection = [projection] left, right = self.left, self.right left_on = _convert_to_list(self.left_on) @@ -391,137 +381,6 @@ def _simplify_up(self, parent): return type(parent)(result) return result[parent_columns] - def _validate_same_operations(self, common, op, remove="both"): - # Travers left and right to check if we can find the same operation - # more than once. We have to account for potential projections on both sides - name = common._name - if name == op._name: - return True, op.left.columns, op.right.columns - - columns_left, columns_right = None, None - op_left, op_right = op.left, op.right - if remove in ("both", "left"): - op_left, columns_left = self._remove_operations( - op.left, self._remove_ops, self._skip_ops - ) - if remove in ("both", "right"): - op_right, columns_right = self._remove_operations( - op.right, self._remove_ops, self._skip_ops - ) - - return ( - type(op)(op_left, op_right, *op.operands[2:])._name == name, - columns_left, - columns_right, - ) - - @staticmethod - def _flatten_columns(expr, columns, side): - if len(columns) == 0: - return getattr(expr, side).columns - else: - return list(set(flatten(columns))) - - def _combine_similar(self, root: Expr): - # Push projections back up to avoid performing the same merge multiple times - - left, columns_left = self._remove_operations( - self.left, self._remove_ops, self._skip_ops - ) - columns_left = self._flatten_columns(self, columns_left, "left") - right, columns_right = self._remove_operations( - self.right, self._remove_ops, self._skip_ops - ) - columns_right = self._flatten_columns(self, columns_right, "right") - - if left._name == self.left._name and right._name == self.right._name: - # There aren't any ops we can remove, so bail - return - - # We can not remove Projections on both sides at once, because only - # one side might need the push back up step. So try if removing Projections - # on either side works before removing them on both sides at once. - - common_left = type(self)(self.left, right, *self.operands[2:]) - common_right = type(self)(left, self.right, *self.operands[2:]) - common_both = type(self)(left, right, *self.operands[2:]) - - columns, left_sub, right_sub = None, None, None - - for op in self._find_similar_operations(root, ignore=["left", "right"]): - if op._name in (common_right._name, common_left._name, common_both._name): - if sorted(self.columns) != sorted(op.columns): - return op[self.columns] - return op - - validation = self._validate_same_operations(common_right, op, "left") - if validation[0]: - left_sub = self._flatten_columns(op, validation[1], side="left") - columns = self.right.columns.copy() - columns += [col for col in self.left.columns if col not in columns] - break - - validation = self._validate_same_operations(common_left, op, "right") - if validation[0]: - right_sub = self._flatten_columns(op, validation[2], side="right") - columns = self.left.columns.copy() - columns += [col for col in self.right.columns if col not in columns] - break - - validation = self._validate_same_operations(common_both, op) - if validation[0]: - left_sub = self._flatten_columns(op, validation[1], side="left") - right_sub = self._flatten_columns(op, validation[2], side="right") - columns = columns_left.copy() - columns += [col for col in columns_right if col not in columns_left] - break - - if columns is not None: - expr = self - if _PARTITION_COLUMN in columns: - columns.remove(_PARTITION_COLUMN) - - if left_sub is not None: - left_sub.extend([col for col in columns_left if col not in left_sub]) - left = self._replace_projections(self.left, sorted(left_sub)) - expr = expr.substitute(self.left, left) - - if right_sub is not None: - right_sub.extend([col for col in columns_right if col not in right_sub]) - right = self._replace_projections(self.right, sorted(right_sub)) - expr = expr.substitute(self.right, right) - - if sorted(expr.columns) != sorted(columns): - expr = expr[columns] - if expr._name == self._name: - return None - return expr - - def _replace_projections(self, frame, new_columns): - # This branch might have a number of Projections that differ from our - # new columns. We replace those projections appropriately - - operations = [] - while isinstance(frame, self._remove_ops + self._skip_ops): - if isinstance(frame, self._remove_ops): - # TODO: Shuffle and AssignPartitioningIndex being 2 different ops - # causes all kinds of pain - if isinstance(frame.frame, AssignPartitioningIndex): - new_cols = new_columns - else: - new_cols = [col for col in new_columns if col != _PARTITION_COLUMN] - - # Ignore Projection if new_columns = frame.frame.columns - if sorted(new_cols) != sorted(frame.frame.columns): - operations.append((type(frame), [new_cols])) - else: - operations.append((type(frame), frame.operands[1:])) - frame = frame.frame - - for op_type, operands in reversed(operations): - frame = op_type(frame, *operands) - return frame - class HashJoinP2P(Merge, PartitionsFiltered): _parameters = [ @@ -643,7 +502,7 @@ def _layer(self) -> dict: ) return dsk - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): return @@ -676,7 +535,7 @@ def _divisions(self): return self.right._divisions() return self.left._divisions() - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): return def _lower(self): diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index f8cd11d51..3142bbf0f 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -30,6 +30,8 @@ RenameSeries, ResetIndex, ToFrame, + determine_column_projection, + plain_column_projection, ) from dask_expr._util import _tokenize_deterministic, is_scalar @@ -506,12 +508,6 @@ def aggregate(cls, inputs: list, **kwargs): df = _concat(inputs) return cls.aggregate_func(df, **kwargs) - def _simplify_up(self, parent): - return - - def __dask_postcompute__(self): - return _concat, () - class DropDuplicates(Unique): _parameters = ["frame", "subset", "ignore_index", "split_every", "split_out"] @@ -541,9 +537,11 @@ def chunk_kwargs(self): out["ignore_index"] = self.ignore_index return out - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if self.subset is not None and isinstance(parent, Projection): - columns = set(parent.columns).union(self.subset) + columns = determine_column_projection( + self, parent, dependents, additional_columns=self.subset + ) if columns == set(self.frame.columns): # Don't add unnecessary Projections, protects against loops return @@ -686,6 +684,10 @@ class Reduction(ApplyConcatApply): reduction_combine = None reduction_aggregate = None + @property + def _projection_columns(self): + return self.frame.columns + @classmethod def chunk(cls, df, **kwargs): out = cls.reduction_chunk(df, **kwargs) @@ -724,9 +726,9 @@ def __str__(self): base = "(" + base + ")" return f"{base}.{self.__class__.__name__.lower()}({s})" - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) class Sum(Reduction): @@ -864,7 +866,7 @@ def _simplify_down(self): if self.frame.ndim == 2 and len(self.frame.columns): return Len(self.frame.index) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): return @@ -881,7 +883,7 @@ def _simplify_down(self): else: return Len(self.frame) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): return @@ -1151,7 +1153,7 @@ def aggregate_kwargs(self): def combine_kwargs(self): return self.chunk_kwargs - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): # We are already a Series return diff --git a/dask_expr/_repartition.py b/dask_expr/_repartition.py index 99962f930..eb3588586 100644 --- a/dask_expr/_repartition.py +++ b/dask_expr/_repartition.py @@ -13,7 +13,7 @@ from pandas.api.types import is_datetime64_any_dtype, is_numeric_dtype from tlz import unique -from dask_expr._expr import Expr, Filter, Projection +from dask_expr._expr import Expr, Projection, plain_column_projection from dask_expr._reductions import TotalMemoryUsageFrame from dask_expr._util import LRU @@ -121,13 +121,9 @@ def _lower(self): else: raise NotImplementedError() - def _combine_similar(self, root: Expr): - return self._combine_similar_branches(root, (Filter, Projection)) - - def _simplify_up(self, parent): - # Reorder with column projection + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) @functools.cached_property def new_partitions(self): diff --git a/dask_expr/_resample.py b/dask_expr/_resample.py index 797b93133..591010e87 100644 --- a/dask_expr/_resample.py +++ b/dask_expr/_resample.py @@ -6,7 +6,13 @@ from dask.dataframe.tseries.resample import _resample_bin_and_out_divs, _resample_series from dask_expr._collection import new_collection -from dask_expr._expr import Blockwise, Expr, Projection, make_meta +from dask_expr._expr import ( + Blockwise, + Expr, + Projection, + make_meta, + plain_column_projection, +) from dask_expr._repartition import Repartition BlockwiseDep = namedtuple(typename="BlockwiseDep", field_names=["iterable"]) @@ -57,9 +63,9 @@ def _resample_divisions(self): self.frame.divisions, self.rule, **self.kwargs or {} ) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - return type(self)(self.frame[parent.operand("columns")], *self.operands[1:]) + return plain_column_projection(self, parent, dependents) def _lower(self): partitioned = Repartition( @@ -176,7 +182,7 @@ class ResampleSem(ResampleReduction): class ResampleAgg(ResampleReduction): how = "agg" - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): # Disable optimization in `agg`; function may access other columns return diff --git a/dask_expr/_rolling.py b/dask_expr/_rolling.py index 838be8903..0649ba5de 100644 --- a/dask_expr/_rolling.py +++ b/dask_expr/_rolling.py @@ -5,7 +5,14 @@ import pandas as pd from dask_expr._collection import new_collection -from dask_expr._expr import Blockwise, Expr, MapOverlap, Projection, make_meta +from dask_expr._expr import ( + Blockwise, + Expr, + MapOverlap, + Projection, + determine_column_projection, + make_meta, +) BlockwiseDep = namedtuple(typename="BlockwiseDep", field_names=["iterable"]) @@ -72,11 +79,12 @@ def _meta(self): def kwargs(self): return {} if self.operand("kwargs") is None else self.operand("kwargs") - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): by = self.groupby_kwargs.get("by", []) if self.groupby_kwargs else [] by_columns = by if not isinstance(by, Expr) else [] - columns = sorted(set(parent.columns + by_columns)) + columns = determine_column_projection(self, parent, dependents, by_columns) + columns = [col for col in self.frame.columns if col in columns] if columns == self.frame.columns: return if self.groupby_kwargs is not None: @@ -204,7 +212,7 @@ class RollingKurt(RollingReduction): class RollingAgg(RollingReduction): how = "agg" - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): # Disable optimization in `agg`; function may access other columns return diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py index 83ec5bc9f..76fc3c4d1 100644 --- a/dask_expr/_shuffle.py +++ b/dask_expr/_shuffle.py @@ -25,9 +25,9 @@ Assign, Blockwise, Expr, - Filter, PartitionsFiltered, Projection, + determine_column_projection, ) from dask_expr._reductions import ( All, @@ -117,13 +117,11 @@ def _lower(self): else: raise ValueError(f"{backend} not supported") - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): # Move the column projection to come # before the abstract Shuffle - projection = parent.operand("columns") - if isinstance(projection, (str, int)): - projection = [projection] + projection = determine_column_projection(self, parent, dependents) partitioning_index = self.partitioning_index if isinstance(partitioning_index, (str, int)): @@ -651,18 +649,6 @@ def operation(df, index, name: str, npartitions: int, assign_index): index = partitioning_index(index, npartitions) return df.assign(**{name: index}) - def _combine_similar(self, root: Expr): - return self._combine_similar_branches(root, (Filter, Projection)) - - def _remove_operations(self, frame, remove_ops, skip_ops=None): - expr, ops = super()._remove_operations(frame, remove_ops, skip_ops) - if len(ops) > 0 and isinstance(ops[0], list): - if sorted(ops[0]) == sorted(self.frame.columns): - expr, ops = super()._remove_operations(frame, remove_ops, skip_ops) - return expr, [] - ops[0] = ops[0] + [self.index_name] - return expr, ops - class BaseSetIndexSortValues(Expr): _is_length_preserving = True @@ -734,6 +720,12 @@ class SetIndex(BaseSetIndexSortValues): "upsample": 1.0, } + @property + def _projection_columns(self): + return self.columns + ( + [self._other] if not isinstance(self._other, Expr) else [] + ) + @functools.cached_property def _meta(self): if isinstance(self._other, Expr): @@ -786,7 +778,7 @@ def _lower(self): self.upsample, ) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): from dask_expr._expr import Filter, Head, Index, Tail # TODO, handle setting index with other frame @@ -813,9 +805,13 @@ def _simplify_up(self, parent): return SetIndex(tail, _other=self._other) if isinstance(parent, Projection): - columns = parent.columns + ( + addition_columns = ( [self._other] if not isinstance(self._other, Expr) else [] ) + columns = determine_column_projection( + self, parent, dependents, additional_columns=addition_columns + ) + columns = _convert_to_list(columns) if self.frame.columns == columns: return return type(parent)( @@ -937,7 +933,7 @@ def _lower(self): shuffled, self.sort_function, self.sort_function_kwargs ) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): from dask_expr._expr import Filter, Head, Tail if isinstance(parent, Head): @@ -970,12 +966,12 @@ def _simplify_up(self, parent): *self.operands[1:], ) if isinstance(parent, Projection): - parent_columns = parent.columns - columns = parent_columns + [ - col for col in self.by if col not in parent_columns - ] + columns = determine_column_projection( + self, parent, dependents, additional_columns=self.by + ) if self.frame.columns == columns: return + columns = [col for col in self.frame.columns if col in columns] return type(parent)( type(self)(self.frame[columns], *self.operands[1:]), parent.operand("columns"), @@ -1115,13 +1111,17 @@ def _divisions(self): return (None,) * (self.frame.npartitions + 1) return tuple(self.new_divisions) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): - columns = parent.columns + ( - _convert_to_list(self.other) if not isinstance(self.other, Expr) else [] + columns = determine_column_projection( + self, + parent, + dependents, + additional_columns=_convert_to_list(self.other), ) if self.frame.columns == columns: return + columns = [col for col in self.frame.columns if col in columns] return type(parent)( type(self)(self.frame[columns], *self.operands[1:]), parent.operand("columns"), diff --git a/dask_expr/_util.py b/dask_expr/_util.py index 96a125e4a..8c0161200 100644 --- a/dask_expr/_util.py +++ b/dask_expr/_util.py @@ -83,7 +83,7 @@ def is_scalar(x): return False if isinstance(x, dict): return False - if isinstance(x, (str, int)): + if isinstance(x, (str, int)) or x is None: return True from dask_expr._expr import Expr diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 0cc4db7f5..07904037c 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -17,6 +17,7 @@ Literal, PartitionsFiltered, Projection, + determine_column_projection, no_default, ) from dask_expr._reductions import Len @@ -59,7 +60,7 @@ class BlockwiseIO(Blockwise, IO): def _fusion_compression_factor(self): return 1 - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if ( self._absorb_projections and isinstance(parent, Projection) @@ -67,69 +68,17 @@ def _simplify_up(self, parent): ): # Column projection parent_columns = parent.operand("columns") - proposed_columns = _convert_to_list(parent_columns) - make_series = isinstance(parent_columns, (str, int)) and not self._series - if set(proposed_columns) == set(self.columns) and not make_series: - # Already projected + proposed_columns = determine_column_projection(self, parent, dependents) + proposed_columns = _convert_to_list(proposed_columns) + proposed_columns = [col for col in self.columns if col in proposed_columns] + if set(proposed_columns) == set(self.columns): + # Already projected or nothing to do return - substitutions = {"columns": _convert_to_list(parent_columns)} - if make_series: - substitutions["_series"] = True - return self.substitute_parameters(substitutions) - - def _combine_similar(self, root: Expr): - if self._absorb_projections: - # For BlockwiseIO expressions with "columns"/"_series" - # attributes (`_absorb_projections == True`), we can avoid - # redundant file-system access by aggregating multiple - # operations with different column projections into the - # same operation. - alike = self._find_similar_operations(root, ignore=["columns", "_series"]) - if alike: - # We have other BlockwiseIO operations (of the same - # sub-type) in the expression graph that can be combined - # with this one. - - # Find the column-projection union needed to combine - # the qualified BlockwiseIO operations - columns_operand = self.operand("columns") - if columns_operand is None: - columns_operand = self.columns - columns = set(columns_operand) - for op in alike: - op_columns = op.operand("columns") - if op_columns is None: - op_columns = op.columns - columns |= set(op_columns) - columns = sorted(columns) - if columns_operand is None: - columns_operand = self.columns - # Can bail if we are not changing columns or the "_series" operand - if columns_operand == columns and ( - len(columns) > 1 or not self._series - ): - return - - # Check if we have the operation we want elsewhere in the graph - for op in alike: - if set(op.columns) == set(columns) and not op.operand("_series"): - return ( - op[columns_operand[0]] - if self._series - else op[columns_operand] - ) - - if set(self.columns) == set(columns): - return # Skip unnecessary projection change - - # Create the "combined" ReadParquet operation - subs = {"columns": columns} - if self._series: - subs["_series"] = False - new = self.substitute_parameters(subs) - return new[columns_operand[0]] if self._series else new[columns_operand] - - return + substitutions = {"columns": proposed_columns} + result = self.substitute_parameters(substitutions) + if result.columns != parent_columns: + result = result[parent_columns] + return result def _tune_up(self, parent): if self._fusion_compression_factor >= 1: @@ -426,7 +375,7 @@ def _get_lengths(self) -> tuple | None: ) return self._pd_length_stats - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Lengths): _lengths = self._get_lengths() if _lengths: @@ -438,7 +387,7 @@ def _simplify_up(self, parent): return Literal(sum(_lengths)) if isinstance(parent, Projection): - return super()._simplify_up(parent) + return super()._simplify_up(parent, dependents) def _divisions(self): return self._divisions_and_locations[0] @@ -518,7 +467,7 @@ def _layer(self) -> dict: ) } - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Projection): if sorted(parent.columns) == sorted(self.names): return diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 525478ab4..cce169be7 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -43,6 +43,7 @@ Literal, Or, Projection, + determine_column_projection, ) from dask_expr._reductions import Len from dask_expr._util import _convert_to_list @@ -450,13 +451,17 @@ def columns(self): else: return _convert_to_list(columns_operand) - def _simplify_up(self, parent): + def _simplify_up(self, parent, dependents): if isinstance(parent, Index): # Column projection - return self.substitute_parameters({"columns": [], "_series": False}) + columns = determine_column_projection(self, parent, dependents) + if set(columns) == set(self.columns): + return + columns = [col for col in self.columns if col in columns] + return self.substitute_parameters({"columns": columns, "_series": False}) if isinstance(parent, Projection): - return super()._simplify_up(parent) + return super()._simplify_up(parent, dependents) if isinstance(parent, Lengths): _lengths = self._get_lengths() diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index 1ff3d0912..4fe702344 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -182,7 +182,10 @@ def test_io_fusion_blockwise(tmpdir): assert df.npartitions == 2 assert len(df.__dask_graph__()) == 2 graph = ( - read_parquet(tmpdir)["a"].repartition(npartitions=4).optimize().__dask_graph__() + read_parquet(tmpdir)["a"] + .repartition(npartitions=4) + .optimize(fuse=False) + .__dask_graph__() ) assert any("readparquet-fused" in key[0] for key in graph.keys()) diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py index 07cb106b1..3b8158d1b 100644 --- a/dask_expr/tests/test_groupby.py +++ b/dask_expr/tests/test_groupby.py @@ -674,7 +674,7 @@ def test_rolling_groupby_projection(): assert_eq(expected, actual, check_divisions=False) optimal = ( - ddf[["group1", "column1"]].groupby("group1").rolling("1D").sum()["column1"] + ddf[["column1", "group1"]].groupby("group1").rolling("1D").sum()["column1"] ) assert actual.optimize()._name == (optimal.optimize()._name) diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py index e81d2cd5c..938acf341 100644 --- a/dask_expr/tests/test_merge.py +++ b/dask_expr/tests/test_merge.py @@ -167,7 +167,6 @@ def test_merge_len(): df = from_pandas(pdf, npartitions=2) pdf2 = lib.DataFrame({"x": [1, 2, 3], "z": 1}) df2 = from_pandas(pdf2, npartitions=2) - assert_eq(len(df.merge(df2)), len(pdf.merge(pdf2))) query = df.merge(df2).index.optimize(fuse=False) expected = df[["x"]].merge(df2[["x"]]).index.optimize(fuse=False) diff --git a/dask_expr/tests/test_reductions.py b/dask_expr/tests/test_reductions.py index bc2c69172..db18eadaf 100644 --- a/dask_expr/tests/test_reductions.py +++ b/dask_expr/tests/test_reductions.py @@ -106,9 +106,9 @@ def test_reductions_split_every_split_out(pdf, df, split_every, reduction): ) q = getattr(df.x, reduction)(split_every=split_every).optimize(fuse=False) if split_every is False: - assert len(q.__dask_graph__()) == 22 + assert len(q.__dask_graph__()) == 32 else: - assert len(q.__dask_graph__()) == 24 + assert len(q.__dask_graph__()) == 34 assert_eq( getattr(df, reduction)(split_every=split_every), getattr(pdf, reduction)(), diff --git a/dask_expr/tests/test_resample.py b/dask_expr/tests/test_resample.py index 31416328d..9f8183786 100644 --- a/dask_expr/tests/test_resample.py +++ b/dask_expr/tests/test_resample.py @@ -59,9 +59,11 @@ def test_resample_apis(df, pdf, api, kwargs): expected = getattr(pdf.resample("2T"), api)()["foo"] assert_eq(result, expected) - q = result.simplify() - eq = getattr(df["foo"].resample("2T"), api)().simplify() - assert q._name == eq._name + if api != "ohlc": + # ohlc actually gives back a DataFrame, so this doesn't work + q = result.simplify() + eq = getattr(df["foo"].resample("2T"), api)().simplify() + assert q._name == eq._name @pytest.mark.parametrize(