diff --git a/src/rules.py b/src/rules.py index 0d69249..c0ab242 100644 --- a/src/rules.py +++ b/src/rules.py @@ -1,12 +1,13 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, Tuple, List, Optional, Set +from typing import Dict, Tuple, List, Optional, Set, Iterable import operator import os from pathlib import Path import functools +import networkx as nx import numpy as np import polars as pl from lark import Lark, Transformer, LarkError @@ -81,6 +82,11 @@ class RuleError(Exception): "filter_compare", } +VISUALIZE_FUNCTIONS = { + "cycle_steps", + "path_subunits", +} + # AST node definitions and Grammar definition @@ -131,8 +137,23 @@ class Steps(Expr): class Call(Expr): value: str args: Tuple[Expr, ...] + allow_visualize_functions: bool = False def validate(self): + if self.value in VISUALIZE_FUNCTIONS and self.allow_visualize_functions: + return + # match self.value: + # case "cycle_steps": + # if len(self.args) != 1: + # raise RuleError(f"cycle_steps(expr) expects exactly 1 argument; got {len(self.args)} with args {self.args}") + # case "path_subunits": + # if len(self.args) != 1: + # raise RuleError(f"path_subunits(expr) expects exactly 1 argument; got {len(self.args)} with args {self.args}") + if self.value in VISUALIZE_FUNCTIONS and not self.allow_visualize_functions: + raise RuleError(f"Visualize function {self.value} cannot be used in rules when allow_visualize_functions is False." 
def __init__(self, allow_visualize_functions=False):
    """Set up the transformer and remember the visualize-function switch.

    Args:
        allow_visualize_functions: Flag stored on the instance; ``call()``
            copies it onto every ``Call`` node it creates, where
            ``Call.validate`` consults it for names listed in
            ``VISUALIZE_FUNCTIONS``.
    """
    # Base-class initialisation does not depend on the flag, so it can run first.
    super().__init__()
    self.allow_visualize_functions = allow_visualize_functions
needed_features=needed_features) + needed_features = set().union(*[s for s in features_by_rules.values()]) + + return cls( + rules=rules_expanded, + needed_features=needed_features, + features_by_rules=features_by_rules, + trees_by_rules=trees_by_rules, + df=lf.collect() + ) def load_rules( @@ -311,6 +348,7 @@ def load_rules( label_col: str = "name", parent_col: str = "parent", rules_col: str = "child", + allow_visualize_functions: bool = False, ) -> Tuple[Dict[str, Expr], Dict[str, Expr]]: """ Assumes TSV has columns at least: name, parent, child @@ -340,7 +378,7 @@ def load_rules( has_parent_col = parent_col and (parent_col in cols) with open(Path(__file__).parent.absolute() / "rules.lark") as f: - parser = Lark(f, parser="lalr", transformer=ASTTransformer()) + parser = Lark(f, parser="lalr", transformer=ASTTransformer(allow_visualize_functions=allow_visualize_functions)) def parse_rule_expr(expr_str: str) -> Expr: """We use a closure to capture the parser instance since @@ -381,6 +419,7 @@ def parse_rule_expr(expr_str: str) -> Expr: .iter_rows() } + definitions = {} if has_parent_col: definitions = { a: b @@ -389,14 +428,12 @@ def parse_rule_expr(expr_str: str) -> Expr: .collect() .iter_rows() } - else: - definitions = {} - return definitions, rules + return definitions, rules, lf def expand_macros( - expr: Expr, definitions: Dict[str, Expr], needed_features: Set[str] = None + expr: Expr, definitions: Dict[str, Expr], needed_features: Set[str] = None, graph: Optional[nx.DiGraph] = None ) -> Expr: """Expand recursively macros in expr using definitions""" memo: Dict[Expr, Expr] = {} @@ -405,33 +442,81 @@ def expand_macros( if needed_features is None: skip_needed = True - def recurse(e: Expr, add_name_to_needed: bool = True) -> Expr: - if e in memo: - return memo[e] + def recurse(expr: Expr, add_name_to_needed: bool = True) -> Expr: + if expr in memo: + return memo[expr] + + entries, exits = None, None - if isinstance(e, Name): - name = e.value + if 
isinstance(expr, Name): + name = expr.value if name in definitions: if name in stack: raise RuleError(f"Cycle detected: {' -> '.join(stack + [name])}") stack.append(name) - out = recurse(definitions[name]) + out, _, _ = recurse(definitions[name]) stack.pop() else: - out = e + out = expr # This will slightly if add_name_to_needed and not skip_needed: needed_features.add(name) - elif isinstance(e, (Number, String)): - out = e - elif isinstance(e, And): - out = And(parts=tuple(recurse(part) for part in e.parts)) - elif isinstance(e, Or): - out = Or(parts=tuple(recurse(part) for part in e.parts)) - elif isinstance(e, Steps): - out = Steps(tuple(recurse(p) for p in e.parts)) - elif isinstance(e, Call): - fn, args = e.value, e.args + entries, exits = {expr}, {expr} + elif isinstance(expr, (Number, String)): + out = expr + elif isinstance(expr, And): + is_atomic = all(not isinstance(p, Steps) for p in expr.parts) + and_entries, and_exits = set(), set() + if is_atomic: + and_entries, and_exits = {expr}, {expr} + outs = [] + for part in expr.parts: + o, entries, exits = recurse(part) + if not is_atomic: + and_entries |= set(entries) + and_exits |= set(exits) + outs.append(o) + entries, exits = and_entries, and_exits + + out = And(parts=tuple(outs)) + elif isinstance(expr, Or): + is_atomic = all(not isinstance(p, Steps) for p in expr.parts) + or_entries, or_exits = set(), set() + if is_atomic: + or_entries, or_exits = {expr}, {expr} + outs = [] + for part in expr.parts: + o, entries, exits = recurse(part) + if not is_atomic: + or_entries |= set(entries) + or_exits |= set(exits) + outs.append(o) + entries, exits = or_entries, or_exits + + out = Or(parts=tuple(outs)) + elif isinstance(expr, Steps): + prev_exits = None + all_entries = None + all_exits = None + parts = [] + for part in expr.parts: + out, entries, exits = recurse(part) + parts.append(out) + if all_entries is None: + all_entries = set(entries) + + if graph is not None and prev_exits is not None: + # connect every 
previous exit to every current entry + for u in prev_exits: + for v in entries: + graph.add_edge(u, v, kind="seq") + + prev_exits = set(exits) + all_exits = set(exits) + entries, exits = all_entries, all_exits + out = Steps(tuple(parts)) + elif isinstance(expr, Call): + fn, args = expr.value, expr.args call_rec = [] # Determine which args should have their Names added to needed_features # since some call args are just numbers, ops, and columns. @@ -453,22 +538,28 @@ def recurse(e: Expr, add_name_to_needed: bool = True) -> Expr: # a gene id. So almost always harmless. else: add_name = True - call_rec.append(recurse(arg, add_name_to_needed=add_name)) - - out = Call(e.value, tuple(call_rec)) - elif isinstance(e, PipeChain): - calls = list(e.calls) - for i, call in enumerate(e.calls): - calls[i] = recurse(call) + o, _, _= recurse(arg, add_name_to_needed=add_name) + call_rec.append(o) + out = Call(expr.value, tuple(call_rec), allow_visualize_functions=expr.allow_visualize_functions) + elif isinstance(expr, PipeChain): + calls = list(expr.calls) + for i, call in enumerate(expr.calls): + o, _, _= recurse(call) + calls[i] = o out = PipeChain(calls=tuple(calls)) else: - raise TypeError(e) + raise TypeError(expr) # Some nodes need validation and have to be done after children are expanded - out.validate() - memo[e] = out - return out - - return recurse(expr) + try: + out.validate() + except Exception as e: + print(f"Error validating expression after macro expansion. 
Expression: {out}") + print(expr) + raise + ret = out, entries or set(), exits or set() + memo[expr] = ret + return ret + return recurse(expr)[0] def build_present_map( @@ -476,18 +567,20 @@ def build_present_map( sample_col: str, besthit_cols: List[str], needed_features: Set[str], + additional_cols: List[str] = None, ) -> Tuple[List[str], Dict[str, np.ndarray]]: """Build present_map of needed gene_ids from annotations DataFrame""" + additional_cols = additional_cols or [] besthit_cols = [col for col in besthit_cols if col in lf.columns] for col in besthit_cols: lf = lf.with_columns(ID_EXPR_DICT[col].alias(col)).explode(col) - lf = lf.select([sample_col] + besthit_cols) + lf = lf.select([sample_col] + besthit_cols + additional_cols) # unpivot to long (sample, hit) hit_col = "hit" lf = ( lf.unpivot( - index=sample_col, + index=[sample_col] + additional_cols, on=besthit_cols, variable_name="db", value_name=hit_col, @@ -517,7 +610,8 @@ def build_present_map( sub_s = sub.select(sample_col).unique().to_series().to_list() for s in sub_s: arr[sample_index[s]] = True - + if additional_cols: + return samples, present_map, lf return samples, present_map @@ -533,8 +627,6 @@ def __init__( self.present_map = present_map self.annotations = annotations self.sample_col = sample_col - self._memo: Dict[Expr, np.ndarray] = {} - self._memo_list: Dict[Expr, np.ndarray] = {} self._order_df = pl.DataFrame( { self.sample_col: self.samples, @@ -544,33 +636,25 @@ def __init__( self._all_false = np.zeros(len(samples), dtype=bool) - def eval_bool(self, expr: Expr) -> np.ndarray: - if expr in self._memo: - return self._memo[expr] - + @functools.cache + def eval_bool(self, expr: Expr, reduce_outer_and=True) -> np.ndarray: + out = None if isinstance(expr, Name): out = self.present_map.get(expr.value, self._all_false) - self._memo[expr] = out - return out if isinstance(expr, And): - out = np.all( - np.stack([self.eval_bool(part) for part in expr.parts], axis=0), axis=0 - ) - self._memo[expr] = 
out - return out + out = np.stack([self.eval_bool(part) for part in expr.parts], axis=1) + if reduce_outer_and: + out = np.all(out, axis=1) if isinstance(expr, Or): try: out = np.any( - np.stack([self.eval_bool(part) for part in expr.parts], axis=0), - axis=0, + np.stack([self.eval_bool(part) for part in expr.parts], axis=1), + axis=1, ) except ValueError as e: - print(f"Error evaluating OR expression: {expr}") - raise - self._memo[expr] = out - return out + raise RuleError(f"Error evaluating OR expression: {expr}") from e if isinstance(expr, PipeChain): out_or_df = self.annotations @@ -581,44 +665,45 @@ def eval_bool(self, expr: Expr) -> np.ndarray: kwargs["df"] = out if isinstance(out, pl.Expr): kwargs["masks"].append(out) - self._memo[expr] = out - return out if isinstance(expr, Call): out = self.eval_call(expr) - self._memo[expr] = out - return out if isinstance(expr, Number): # numeric alone isn't boolean; treat as error raise RuleError(f"Literal used where boolean expected: {expr}") if isinstance(expr, Steps): - raise RuleError( - "Step expressions (expression seperated by commas)" - " logic cannot be evaluted on their own. They must be used" - " in a supporting function such as `percent`. Offending" - f" rule: {expr}" - ) + out = self.eval_cycle(expr) + + if out is not None: + return out raise RuleError( "Something failed to parse properly." 
def eval_cycle(self, expr: Steps | list[Expr], simplify=True, **kwargs) -> np.ndarray:
    """Evaluate every step of a cycle and stack the per-step results.

    Results are memoised per instance. The previous ``@functools.cache``
    decorator raised ``TypeError`` for ``list`` inputs (lists are
    unhashable) even though the signature accepts them, and it kept every
    instance alive in a module-level cache.

    Args:
        expr: A ``Steps`` node (its ``parts`` are the steps) or any iterable
            of expressions, each evaluated with ``eval_bool``.
        simplify: When true, collapse the stacked result with
            ``any(axis=1)``; when false, return the stacked matrix.
        **kwargs: Forwarded unchanged to ``eval_bool`` for every step.

    Returns:
        The per-step results stacked along axis 1, or a
        ``(len(self.samples), 0)`` boolean matrix when there are no steps;
        reduced with ``any(axis=1)`` if ``simplify`` is true.
        NOTE(review): the returned array is shared with the memo, so callers
        should treat it as read-only.

    Raises:
        RuleError: If ``expr`` is neither a ``Steps`` node nor an iterable.
    """
    # Normalise to a tuple so the steps can be used as a dictionary key.
    if isinstance(expr, Steps):
        step_exprs = tuple(expr.parts)
    elif isinstance(expr, Iterable):
        step_exprs = tuple(expr)
    else:
        raise RuleError(f"Expected Steps or Iterable of Expr for cycle evaluation, got {expr}")

    # Lazily created per-instance memo; it is garbage-collected with the instance.
    memo = getattr(self, "_cycle_memo", None)
    if memo is None:
        memo = {}
        self._cycle_memo = memo

    key = (step_exprs, bool(simplify), tuple(sorted(kwargs.items())))
    try:
        cached = memo.get(key)
    except TypeError:
        # Something in the key is unhashable: compute without caching.
        key = None
        cached = None
    if cached is not None:
        return cached

    parts = [self.eval_bool(step, **kwargs) for step in step_exprs]
    if parts:
        mat = np.stack(parts, axis=1)
    else:
        mat = np.zeros((len(self.samples), 0), dtype=bool)
    if simplify:
        mat = mat.any(axis=1)

    if key is not None:
        memo[key] = mat
    return mat
Function: {fn}") @@ -748,7 +838,7 @@ def column_count_values( df = self.annotations if df is None else df if col not in df.columns: - raise RuleError(f"Missing column '{col}' for column_count_values()") + return np.zeros(len(self.samples), dtype=bool) if val_op not in OP_TO_EXPR: raise ValueError( f"Unsupported value operation={val_op!r} to compare values for column {col}. Use one of {sorted(OP_TO_EXPR)}" @@ -782,7 +872,7 @@ def column_sum_values( df = self.annotations if df is None else df if col not in df.columns: - raise RuleError(f"Missing column '{col}' for column_sum_values()") + return np.zeros(len(self.samples), dtype=bool) try: cmp_fn = OP_TO_EXPR[op] except KeyError: @@ -799,7 +889,7 @@ def filter_contains( ) -> pl.DataFrame: df = self.annotations if df is None else df if col not in df: - raise RuleError(f"Missing column '{col}' for filter_contains()") + return np.zeros(len(self.samples), dtype=bool) return pl.col(col).str.contains(val) def filter_compare( @@ -814,6 +904,31 @@ def filter_compare( raise ValueError(f"Unsupported op={op!r}. Use one of {sorted(OP_TO_EXPR)}") return cmp_fn(pl.col(col).cast(float), thr) + def cycle_steps(self, expr: Steps | list[Expr], **kwargs) -> pl.DataFrame: + cycle = self.eval_cycle(expr, simplify=False, **kwargs) + + df = pl.DataFrame( + { + "steps": cycle.shape[-1], + # sum across all axis except the first one. 
def evaluate_cycles(
    compiled: CompiledRules,
    samples: List[str],
    present_map: Dict[str, np.ndarray],
    group_col: Optional[str] = "group",
    label_col: Optional[str] = "name",
    annotations: Optional[pl.DataFrame] = None,
    sample_col: Optional[str] = None,
    additional_cols: Optional[List[str]] = None,
) -> Dict[object, pl.DataFrame]:
    """Evaluate every rule group by group and return one frame per group.

    The rules table ``compiled.df`` is grouped by ``group_col`` in its
    original order. Each non-null label in ``label_col`` is looked up in
    ``compiled.rules`` and evaluated with ``Evaluator.eval_bool``. A NumPy
    result is wrapped in a frame with ``present`` and ``genome`` columns;
    a result that is already a DataFrame is used as is. Each result is
    tagged with its rule label, the group's results are concatenated, and
    the label, group and ``additional_cols`` columns of ``compiled.df`` are
    joined on.

    Args:
        compiled: Compiled rules holding ``rules`` and the rules table ``df``.
        samples: Sample identifiers passed to the ``Evaluator``.
        present_map: Presence arrays passed to the ``Evaluator``.
        group_col: Column of ``compiled.df`` to group rules by.
        label_col: Column of ``compiled.df`` holding the rule names.
        annotations: Optional annotation table passed to the ``Evaluator``.
        sample_col: Sample column name passed to the ``Evaluator``.
        additional_cols: Extra ``compiled.df`` columns to carry into the
            output.

    Returns:
        A dict mapping each group value to its joined result frame. Groups
        whose labels are all null are omitted; previously they crashed in
        ``pl.concat`` on an empty list.

    Raises:
        RuleError: If a rule evaluates to something that is neither a NumPy
            array nor a polars DataFrame.
        KeyError: If a label in ``compiled.df`` is missing from
            ``compiled.rules``.
    """
    ev = Evaluator(
        samples=samples,
        present_map=present_map,
        sample_col=sample_col,
        annotations=annotations,
    )
    extra_cols = [pl.col(c) for c in additional_cols] if additional_cols else []
    # Rule metadata is identical for every group, so select it once.
    rule_meta = compiled.df.select(pl.col(label_col), pl.col(group_col), *extra_cols)

    results = {}
    for group_key, frame in compiled.df.group_by(group_col, maintain_order=True):
        group = group_key[0]
        frames = []
        for rule_name in frame.select(pl.col(label_col)).to_series():
            if rule_name is None:
                continue
            out = ev.eval_bool(compiled.rules[rule_name])
            if isinstance(out, np.ndarray):
                out = pl.DataFrame(dict(present=out, genome=ev.samples))
            if not isinstance(out, pl.DataFrame):
                # Explicit check instead of ``assert``, which is stripped under -O.
                raise RuleError(
                    f"Rule {rule_name!r} in group {group} evaluated to"
                    f" {type(out)}; expected a numpy array or a polars DataFrame"
                )
            frames.append(out.with_columns(pl.lit(rule_name).alias(label_col)))

        if not frames:
            # Nothing to concatenate for this group.
            continue
        results[group] = pl.concat(frames).join(rule_meta, on=label_col)

    return results