Merged. Changes from all commits.
dask_expr/_align.py (6 changes: 3 additions & 3 deletions)
@@ -2,7 +2,7 @@

 from tlz import merge_sorted, unique

-from dask_expr._expr import Expr, Projection, is_broadcastable
+from dask_expr._expr import Expr, Projection, is_broadcastable, plain_column_projection
 from dask_expr._repartition import RepartitionDivisions


@@ -25,9 +25,9 @@ def _divisions(self):
             divisions = (divisions[0], divisions[0])
         return divisions

-    def _simplify_up(self, parent):
+    def _simplify_up(self, parent, dependents):
         if isinstance(parent, Projection):
-            return type(self)(self.frame[parent.operand("columns")], *self.operands[1:])
+            return plain_column_projection(self, parent, dependents)

     def _lower(self):
         if not self.dfs:
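Note: here (and in dask_expr/_cumulative.py below) the hand-rolled projection pushdown is replaced by the shared plain_column_projection helper, which also consults the new dependents mapping. A minimal sketch of the presumed helper shape, inferred purely from the call sites in this diff; the real implementation lives in dask_expr._expr and may differ:

    # Sketch only, inferred from call sites; not the actual dask_expr code.
    def plain_column_projection(expr, parent, dependents):
        # Project the child frame down to the columns the consumers need,
        # then rebuild the expression on top of the projected frame.
        columns = determine_column_projection(expr, parent, dependents)
        return type(expr)(expr.frame[columns], *expr.operands[1:])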
dask_expr/_collection.py (14 changes: 6 additions & 8 deletions)
@@ -251,15 +251,15 @@ def __bool__(self):
             "Use a.any() or a.all()."
         )

-    def persist(self, fuse=True, combine_similar=True, **kwargs):
-        out = self.optimize(combine_similar=combine_similar, fuse=fuse)
+    def persist(self, fuse=True, **kwargs):
+        out = self.optimize(fuse=fuse)
         return DaskMethodsMixin.persist(out, **kwargs)

-    def compute(self, fuse=True, combine_similar=True, **kwargs):
+    def compute(self, fuse=True, **kwargs):
         out = self
         if not isinstance(out, Scalar):
             out = out.repartition(npartitions=1)
-        out = out.optimize(combine_similar=combine_similar, fuse=fuse)
+        out = out.optimize(fuse=fuse)
         return DaskMethodsMixin.compute(out, **kwargs)

     def __dask_graph__(self):
@@ -278,10 +278,8 @@ def simplify(self):
     def lower_once(self):
         return new_collection(self.expr.lower_once())

-    def optimize(self, combine_similar: bool = True, fuse: bool = True):
-        return new_collection(
-            self.expr.optimize(combine_similar=combine_similar, fuse=fuse)
-        )
+    def optimize(self, fuse: bool = True):
+        return new_collection(self.expr.optimize(fuse=fuse))

     @property
     def dask(self):
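Note: with combine_similar retired, the user-facing signatures reduce to optimize(fuse=...), compute(fuse=...), and persist(fuse=...). A quick check of the new API (data is illustrative):

    import pandas as pd
    import dask_expr as dx

    df = dx.from_pandas(pd.DataFrame({"x": range(100), "y": range(100)}), npartitions=4)
    out = (df["x"] + 1).optimize(fuse=True)  # combine_similar= would now raise TypeError
    print(out.compute(fuse=True))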
dask_expr/_concat.py (13 changes: 10 additions & 3 deletions)
@@ -8,7 +8,14 @@
 from dask.dataframe.utils import check_meta, strip_unknown_categories
 from dask.utils import apply, is_dataframe_like, is_series_like

-from dask_expr._expr import AsType, Blockwise, Expr, Projection, are_co_aligned
+from dask_expr._expr import (
+    AsType,
+    Blockwise,
+    Expr,
+    Projection,
+    are_co_aligned,
+    determine_column_projection,
+)


 class Concat(Expr):
@@ -139,13 +146,13 @@ def _lower(self):
             *cast_dfs,
         )

-    def _simplify_up(self, parent):
+    def _simplify_up(self, parent, dependents):
         if isinstance(parent, Projection):

             def get_columns_or_name(e: Expr):
                 return e.columns if e.ndim == 2 else [e.name]

-            columns = parent.columns
+            columns = determine_column_projection(self, parent, dependents)
             columns_frame = [
                 [col for col in get_columns_or_name(frame) if col in columns]
                 for frame in self._frames
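Note: Concat no longer trusts the single parent projection (columns = parent.columns); it asks determine_column_projection for the columns needed across all dependents. A hedged sketch of the presumed semantics, based only on the name and call sites in this diff (the real helper is in dask_expr._expr and may differ):

    # Sketch only; presumed semantics, not the actual dask_expr code.
    def determine_column_projection(expr, parent, dependents):
        # Union the columns requested by every Projection consuming `expr`,
        # so pushing one projection down cannot starve other consumers.
        columns = set()
        for ref in dependents[expr._name]:
            node = ref()  # entries are weakrefs (see collect_dependents below)
            if node is not None and isinstance(node, Projection):
                cols = node.operand("columns")
                columns.update(cols if isinstance(cols, list) else [cols])
        return sorted(columns)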
dask_expr/_core.py (135 changes: 78 additions & 57 deletions)
@@ -2,6 +2,8 @@

 import functools
 import os
+import weakref
+from collections import defaultdict
 from collections.abc import Generator

 import dask
@@ -10,7 +12,7 @@
 from dask.dataframe.core import is_dataframe_like, is_index_like, is_series_like
 from dask.utils import funcname, import_required, is_arraylike

-from dask_expr._util import _BackendData, _tokenize_deterministic, _tokenize_partial
+from dask_expr._util import _BackendData, _tokenize_deterministic


 class Expr:
@@ -245,25 +247,79 @@ def rewrite(self, kind: str):

         return expr

-    def simplify(self):
+    def simplify_once(self, dependents: defaultdict):
         """Simplify an expression

         This leverages the ``._simplify_down`` and ``._simplify_up``
         methods defined on each class

+        Parameters
+        ----------
+
+        dependents: defaultdict[list]
+            The dependents for every node.
+
         Returns
         -------
         expr:
             output expression
+        changed:
+            whether or not any change occurred
         """
-        return self.rewrite(kind="simplify")
+        expr = self
+
+        while True:
+            out = expr._simplify_down()
+            if out is None:
+                out = expr
+            if not isinstance(out, Expr):
+                return out
+            if out._name != expr._name:
+                expr = out
+
+            # Allow children to simplify their parents
+            for child in expr.dependencies():
+                out = child._simplify_up(expr, dependents)
+                if out is None:
+                    out = expr
+
+                if not isinstance(out, Expr):
+                    return out
+                if out is not expr and out._name != expr._name:
+                    expr = out
+                    break
+
+            # Rewrite all of the children
+            new_operands = []
+            changed = False
+            for operand in expr.operands:
+                if isinstance(operand, Expr):
+                    new = operand.simplify_once(dependents=dependents)
+                    if new._name != operand._name:
+                        changed = True
+                else:
+                    new = operand
+                new_operands.append(new)
+
+            if changed:
+                expr = type(expr)(*new_operands)
+
+            break
+
+        return expr
+
+    def simplify(self) -> Expr:
+        expr = self
+        while True:
+            dependents = collect_dependents(expr)
+            new = expr.simplify_once(dependents=dependents)
+            if new._name == expr._name:
+                break
+            expr = new
+        return expr

     def _simplify_down(self):
         return

-    def _simplify_up(self, parent):
+    def _simplify_up(self, parent, dependents):
         return

     def lower_once(self):
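Note: the payoff of threading dependents through _simplify_up is that a node can be rewritten with every consumer in view, not just the one parent being matched. For example, with a shared subexpression consumed by two different projections (illustrative; the resulting plan depends on the individual simplification rules):

    import pandas as pd
    import dask_expr as dx

    pdf = pd.DataFrame({"x": range(10), "y": range(10), "z": range(10)})
    df = dx.from_pandas(pdf, npartitions=2)
    shared = df + 1                       # one shared subexpression...
    result = shared["x"] + shared["y"]    # ...with two column consumers

    # With dependents available, simplify can project ["x", "y"] below the
    # shared `df + 1` instead of duplicating the addition per branch.
    result.simplify().expr.pprint()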
@@ -322,42 +378,6 @@ def lower_completely(self) -> Expr:
     def _lower(self):
         return

-    def _remove_operations(self, frame, remove_ops, skip_ops=None):
-        """Searches for operations that we have to push up again to avoid
-        the duplication of branches that are doing the same.
-
-        Parameters
-        ----------
-        frame: Expression that we will search.
-        remove_ops: Ops that we will remove to push up again.
-        skip_ops: Ops that were introduced and that we want to ignore.
-
-        Returns
-        -------
-        tuple of the new expression and the operations that we removed.
-        """
-
-        operations, ops_to_push_up = [], []
-        frame_base = frame
-        combined_ops = remove_ops if skip_ops is None else remove_ops + skip_ops
-        while isinstance(frame, combined_ops):
-            # Have to respect ops that were injected while lowering or filters
-            if isinstance(frame, remove_ops):
-                ops_to_push_up.append(frame.operands[1])
-                frame = frame.frame
-                break
-            else:
-                operations.append((type(frame), frame.operands[1:]))
-                frame = frame.frame
-
-        if len(ops_to_push_up) > 0:
-            # Remove the projections but build the remaining things back up
-            for op_type, operands in reversed(operations):
-                frame = op_type(frame, *operands)
-            return frame, ops_to_push_up
-        else:
-            return frame_base, []

     @functools.cached_property
     def _name(self):
         return (
@@ -509,21 +529,6 @@ def substitute_parameters(self, substitutions: dict) -> Expr:
             return type(self)(*new_operands)
         return self

-    def _find_similar_operations(self, root: Expr, ignore: list | None = None):
-        # Find operations with the same type and operands.
-        # Parameter keys specified by `ignore` will not be
-        # included in the operand comparison
-        alike = [
-            op for op in root.find_operations(type(self)) if op._name != self._name
-        ]
-        if not alike:
-            # No other operations of the same type. Early return
-            return []
-
-        # Return subset of `alike` with the same "token"
-        token = _tokenize_partial(self, ignore)
-        return [item for item in alike if _tokenize_partial(item, ignore) == token]

     def _node_label_args(self):
         """Operands to include in the node label by `visualize`"""
         return self.dependencies()
@@ -668,3 +673,19 @@ def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
             or issubclass(operation, Expr)
         ), "`operation` must be`Expr` subclass)"
         return (expr for expr in self.walk() if isinstance(expr, operation))
+
+
+def collect_dependents(expr) -> defaultdict:
+    dependents = defaultdict(list)
+    stack = [expr]
+    seen = set()
+    while stack:
+        node = stack.pop()
+        if node._name in seen:
+            continue
+        seen.add(node._name)
+
+        for dep in node.dependencies():
+            stack.append(dep)
+            dependents[dep._name].append(weakref.ref(node))
+    return dependents
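Note: collect_dependents walks the graph once and records each node's consumers as weak references, so the mapping does not keep superseded expression objects alive while the tree is being rewritten. Anything reading the mapping must therefore dereference and tolerate None. A small self-contained usage sketch under that reading:

    import pandas as pd
    import dask_expr as dx
    from dask_expr._core import collect_dependents

    expr = (dx.from_pandas(pd.DataFrame({"x": range(4)}), npartitions=2) + 1).expr
    dependents = collect_dependents(expr)
    for child in expr.dependencies():
        for ref in dependents[child._name]:
            consumer = ref()        # entries are weakref.ref; call to dereference
            if consumer is None:    # consumer may have been collected mid-rewrite
                continue
            print(f"{child._name} <- {type(consumer).__name__}")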
dask_expr/_cumulative.py (6 changes: 3 additions & 3 deletions)
@@ -3,7 +3,7 @@
 from dask.dataframe import methods
 from dask.utils import M

-from dask_expr._expr import Blockwise, Expr, Projection
+from dask_expr._expr import Blockwise, Expr, Projection, plain_column_projection


 class CumulativeAggregations(Expr):
@@ -27,9 +27,9 @@ def _lower(self):
         chunks_last = TakeLast(chunks, self.skipna)
         return CumulativeFinalize(chunks, chunks_last, self.aggregate_operation)

-    def _simplify_up(self, parent):
+    def _simplify_up(self, parent, dependents):
         if isinstance(parent, Projection):
-            return type(self)(self.frame[parent.operand("columns")], *self.operands[1:])
+            return plain_column_projection(self, parent, dependents)


 class CumulativeBlockwise(Blockwise):
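Note: same helper swap as in dask_expr/_align.py. End to end, this is the pushdown it enables (illustrative; the exact plan depends on the simplification rules):

    import pandas as pd
    import dask_expr as dx

    df = dx.from_pandas(pd.DataFrame({"x": range(10), "y": range(10)}), npartitions=2)
    # Selecting one column after a cumulative op can now be rewritten via
    # plain_column_projection so only "x" flows into the cumsum chunks.
    df.cumsum()["x"].simplify().expr.pprint()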