diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index fa58308faced..498f77a3f7d5 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -65,15 +65,6 @@ TVM_DLL tvm::runtime::Map MatchGraph(const PatternContext& ctx, Optional start_hint = NullOpt, bool must_include_hint = false); -/** - * \brief Match a graph-wise pattern with the current context (PatternContext::Current()). - */ -inline tvm::runtime::Map MatchGraphDefault(const DataflowBlock& dfb, - Optional start_hint = NullOpt, - bool must_include_hint = false) { - return MatchGraph(PatternContext::Current(), dfb, start_hint, must_include_hint); -} - } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 144a7f45bf57..e4c27f3558ba 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -245,15 +245,15 @@ class PatternContext : public ObjectRef { } } - /*! \brief Get the pass context object on the top of the stack */ - TVM_DLL static PatternContext Current(); + /*! \brief Get the constraint context object on the top of the stack */ + TVM_DLL static Optional Current(); class Internal; private: - /*! \brief The RAII-like entry of a pass context scope */ + /*! \brief The RAII-like entry of a constraint context scope */ TVM_DLL void EnterWithScope(); - /*! \brief The RAII-like exit of a pass context scope */ + /*! \brief The RAII-like exit of a constraint context scope */ TVM_DLL void ExitWithScope(); friend class Internal; friend class With; diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 248e9577264b..acabac2dcbf1 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -52,7 +52,7 @@ def register_df_node(type_key=None): class DFPattern(Node): """Base class of all Patterns.""" - def __call__(self, *args, varg_default_wildcard=False) -> "CallPattern": + def __call__(self, *args, varg_default_wildcard=False, add_constraint=True) -> "CallPattern": """ Syntax sugar for creating a CallPattern with argument patterns @@ -61,7 +61,7 @@ def __call__(self, *args, varg_default_wildcard=False) -> "CallPattern": result: CallPattern The resulting CallPattern """ - return CallPattern(self, args, varg_default_wildcard) + return CallPattern(self, args, varg_default_wildcard, add_constraint) def __or__(self, other: "DFPattern") -> "OrPattern": """ @@ -387,6 +387,9 @@ class CallPattern(DFPattern): varg_default_wildcard: bool If True, args can be fewer than actual provided arguments. + add_constraint: bool + If True, automatically add "used-by" constraints between caller and callee expressions. + Note ---- By setting varg_default_wildcard to True, we can only focus on the argument @@ -400,11 +403,16 @@ def __init__( op: "DFPattern", args: Union[List["DFPattern"], typing.Tuple["DFPattern", ...]], varg_default_wildcard: bool = False, + add_constraint=True, ): self.__init_handle_by_constructor__( ffi.CallPattern, op, args, varg_default_wildcard # type: ignore ) + if add_constraint: + for i, arg in enumerate(args): + arg.used_by(self, i) + @register_df_node class FunctionPattern(DFPattern): @@ -835,7 +843,7 @@ def _is_call_tir( elif isinstance(args, (list, tuple)): args = TuplePattern(args) - return is_op("relax.call_tir")(func_pattern, args) + return is_op("relax.call_tir")(func_pattern, args, add_constraint=False) # Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo @@ -871,7 +879,7 @@ def _is_call_dps_packed( elif isinstance(args, (list, tuple)): args = TuplePattern(args) - return is_op("relax.call_dps_packed")(func_pattern, args) + return is_op("relax.call_dps_packed")(func_pattern, args, add_constraint=False) def is_call_dps_packed( @@ -915,7 +923,7 @@ def is_call_packed( The resulting CallPattern """ if args is None: - return ExternFuncPattern(func_name)(varg_default_wildcard=True) + return ExternFuncPattern(func_name)(varg_default_wildcard=True, add_constraint=False) return ExternFuncPattern(func_name)(*args) diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 5eb1bf3ea6f6..5580f6a1ab74 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -394,8 +394,8 @@ std::stack& pattern_ctx_stack() { return graph_pattern_managers; } -PatternContext PatternContext::Current() { - ICHECK(!pattern_ctx_stack().empty()) << "No active PatternContext found."; +Optional PatternContext::Current() { + if (pattern_ctx_stack().empty()) return NullOpt; return pattern_ctx_stack().top(); } @@ -419,7 +419,9 @@ void PatternContext::ExitWithScope() { } static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs, PairCons pcon) { - PatternContext::Current().add_constraint(lhs, rhs, pcon); + if (auto ctx = PatternContext::Current()) { + ctx.value().add_constraint(lhs, rhs, pcon); + } } TVM_REGISTER_NODE_TYPE(PatternSeqNode); diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 9679e14fffe7..76bce47f7ff7 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1034,15 +1034,6 @@ def main( matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) - # TODO(masahi): Automate addition of used_by constraints during is_op - inp_pat.used_by(matmul1, 0) - inp_pat.used_by(matmul2, 0) - inp_pat.used_by(matmul3, 0) - - Q_weight_pat.only_used_by(matmul1, 1) - K_weight_pat.only_used_by(matmul2, 1) - V_weight_pat.only_used_by(matmul3, 1) - dfb = QKV_proj["main"].body.blocks[0] out = ctx.match_dfb(dfb)