Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions include/tvm/relax/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,6 @@ TVM_DLL tvm::runtime::Map<DFPattern, Var> MatchGraph(const PatternContext& ctx,
Optional<Var> start_hint = NullOpt,
bool must_include_hint = false);

/**
* \brief Match a graph-wise pattern with the current context (PatternContext::Current()).
*/
inline tvm::runtime::Map<DFPattern, Var> MatchGraphDefault(const DataflowBlock& dfb,
Optional<Var> start_hint = NullOpt,
bool must_include_hint = false) {
return MatchGraph(PatternContext::Current(), dfb, start_hint, must_include_hint);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gets removed since it is not used anywhere but it touches PatternContext::Current() (which is now Optional)

}

} // namespace relax
} // namespace tvm

Expand Down
8 changes: 4 additions & 4 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,15 @@ class PatternContext : public ObjectRef {
}
}

/*! \brief Get the pass context object on the top of the stack */
TVM_DLL static PatternContext Current();
/*! \brief Get the constraint context object on the top of the stack */
TVM_DLL static Optional<PatternContext> Current();

class Internal;

private:
/*! \brief The RAII-like entry of a pass context scope */
/*! \brief The RAII-like entry of a constraint context scope */
TVM_DLL void EnterWithScope();
/*! \brief The RAII-like exit of a pass context scope */
/*! \brief The RAII-like exit of a constraint context scope */
TVM_DLL void ExitWithScope();
friend class Internal;
friend class With<PatternContext>;
Expand Down
18 changes: 13 additions & 5 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def register_df_node(type_key=None):
class DFPattern(Node):
"""Base class of all Patterns."""

def __call__(self, *args, varg_default_wildcard=False) -> "CallPattern":
def __call__(self, *args, varg_default_wildcard=False, add_constraint=True) -> "CallPattern":
"""
Syntax sugar for creating a CallPattern with argument patterns

Expand All @@ -61,7 +61,7 @@ def __call__(self, *args, varg_default_wildcard=False) -> "CallPattern":
result: CallPattern
The resulting CallPattern
"""
return CallPattern(self, args, varg_default_wildcard)
return CallPattern(self, args, varg_default_wildcard, add_constraint)

def __or__(self, other: "DFPattern") -> "OrPattern":
"""
Expand Down Expand Up @@ -387,6 +387,9 @@ class CallPattern(DFPattern):
varg_default_wildcard: bool
If True, args can be fewer than actual provided arguments.

add_constraint: bool
If True, automatically add "used-by" constraints between caller and callee expressions.

Note
----
By setting varg_default_wildcard to True, we can only focus on the argument
Expand All @@ -400,11 +403,16 @@ def __init__(
op: "DFPattern",
args: Union[List["DFPattern"], typing.Tuple["DFPattern", ...]],
varg_default_wildcard: bool = False,
add_constraint=True,
):
self.__init_handle_by_constructor__(
ffi.CallPattern, op, args, varg_default_wildcard # type: ignore
)

if add_constraint:
for i, arg in enumerate(args):
arg.used_by(self, i)


@register_df_node
class FunctionPattern(DFPattern):
Expand Down Expand Up @@ -835,7 +843,7 @@ def _is_call_tir(
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)

return is_op("relax.call_tir")(func_pattern, args)
return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)


# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
Expand Down Expand Up @@ -871,7 +879,7 @@ def _is_call_dps_packed(
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)

return is_op("relax.call_dps_packed")(func_pattern, args)
return is_op("relax.call_dps_packed")(func_pattern, args, add_constraint=False)


def is_call_dps_packed(
Expand Down Expand Up @@ -915,7 +923,7 @@ def is_call_packed(
The resulting CallPattern
"""
if args is None:
return ExternFuncPattern(func_name)(varg_default_wildcard=True)
return ExternFuncPattern(func_name)(varg_default_wildcard=True, add_constraint=False)
return ExternFuncPattern(func_name)(*args)


Expand Down
8 changes: 5 additions & 3 deletions src/relax/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ std::stack<PatternContext>& pattern_ctx_stack() {
return graph_pattern_managers;
}

PatternContext PatternContext::Current() {
ICHECK(!pattern_ctx_stack().empty()) << "No active PatternContext found.";
Optional<PatternContext> PatternContext::Current() {
if (pattern_ctx_stack().empty()) return NullOpt;
return pattern_ctx_stack().top();
}

Expand All @@ -419,7 +419,9 @@ void PatternContext::ExitWithScope() {
}

static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs, PairCons pcon) {
PatternContext::Current().add_constraint(lhs, rhs, pcon);
if (auto ctx = PatternContext::Current()) {
ctx.value().add_constraint(lhs, rhs, pcon);
}
}

TVM_REGISTER_NODE_TYPE(PatternSeqNode);
Expand Down
9 changes: 0 additions & 9 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,15 +1034,6 @@ def main(
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

# TODO(masahi): Automate addition of used_by constraints during is_op
inp_pat.used_by(matmul1, 0)
inp_pat.used_by(matmul2, 0)
inp_pat.used_by(matmul3, 0)

Q_weight_pat.only_used_by(matmul1, 1)
K_weight_pat.only_used_by(matmul2, 1)
V_weight_pat.only_used_by(matmul3, 1)
Comment on lines -1042 to -1044
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand used_by can be removed now as it is implied by a function call. However, only_used_by is a stronger constraint that cannot be implied via a fn call right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But yeah I guess you meant oftentimes we don't fold Q, K, and V so it is fine if they are being used by others at the same time.


dfb = QKV_proj["main"].body.blocks[0]
out = ctx.match_dfb(dfb)

Expand Down