Skip to content
Merged
9 changes: 1 addition & 8 deletions include/tvm/relax/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,12 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(

/**
* \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping.
* \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the
* starting point of the matching so that we can distinguish multiple matches.
*
* \param ctx The graph-wise patterns.
* \param dfb The function to match.
* \param start_hint The starting point expression to match to distinguish multiple matches.
* \param must_include_hint If start_hint is given, the return pattern must include start_hint.
* \return Matched patterns and corresponding bound variables
*/
TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
const DataflowBlock& dfb,
Optional<Var> start_hint = NullOpt,
bool must_include_hint = false);
const DataflowBlock& dfb);

} // namespace relax
} // namespace tvm
Expand Down
10 changes: 2 additions & 8 deletions python/tvm/relax/dpl/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

"""The Graph Matching Context Manager for Dataflow Pattern Language."""

from typing import Optional, Dict
from typing import Dict

import tvm
from ..expr import DataflowBlock, Var
Expand Down Expand Up @@ -63,8 +63,6 @@ def current() -> "PatternContext":
def match_dfb(
self,
dfb: DataflowBlock,
start_hint: Optional[Var] = None,
must_include_hint: bool = False,
) -> Dict[DFPattern, Var]:
"""
Match a DataflowBlock via a graph of DFPattern and corresponding constraints
Expand All @@ -73,14 +71,10 @@ def match_dfb(
----------
dfb : DataflowBlock
The DataflowBlock to match
start_hint : Optional[Var], optional
Indicating the starting expression to match, by default None
must_include_hint : bool, optional
Whether the start_hint expression must be matched, by default False

Returns
-------
Dict[DFPattern, Var]
The mapping from DFPattern to matched expression
"""
return ffi.match_dfb(self, dfb, start_hint, must_include_hint) # type: ignore
return ffi.match_dfb(self, dfb) # type: ignore
Loading