From 46f4e550b743986c1e2557fd6fd7e2a21a04fe67 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 9 Mar 2020 13:45:17 -0700 Subject: [PATCH 01/46] Add initial pattern ast scaffolding --- include/tvm/relay/dataflow_pattern.h | 83 +++++++++++++++++++++++++ python/tvm/relay/df_pattern/__init__.py | 35 +++++++++++ python/tvm/relay/df_pattern/_ffi.py | 20 ++++++ src/relay/ir/dataflow_pattern.cc | 52 ++++++++++++++++ 4 files changed, 190 insertions(+) create mode 100644 include/tvm/relay/dataflow_pattern.h create mode 100644 python/tvm/relay/df_pattern/__init__.py create mode 100644 python/tvm/relay/df_pattern/_ffi.py create mode 100644 src/relay/ir/dataflow_pattern.cc diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h new file mode 100644 index 000000000000..f973298f7665 --- /dev/null +++ b/include/tvm/relay/dataflow_pattern.h @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/dataflow_pattern.h + * \brief A pattern language for matching dataflow properties. + */ +#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_ +#define TVM_RELAY_DATAFLOW_PATTERN_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Base type of all dataflow patterns. + * \sa DFPattern + */ +class DFPatternNode : public Object { + public: + static constexpr const char* _type_key = "DFPatternNode"; + TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object); +}; + +/*! + * \brief Managed reference to dataflow patterns. + * \sa DFPatternNode + */ +class DFPattern : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a literal expression. + * + * \note Uses structural equality on expressions to check equality. + * + */ +class ExprPattern; +/*! + * \brief Constant tensor type. + */ +class ExprPatternNode : public DFPatternNode { + public: + /*! \brief The expression to match. */ + Expr expr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("expr", &expr); + } + + static constexpr const char* _type_key = "relay.pattern.Expr"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); +}; + +class ExprPattern : public DFPattern { + public: + TVM_DLL ExprPattern(Expr expr); + TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_DATAFLOW_PATTERN_H_ diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py new file mode 100644 index 000000000000..6d241c690a9f --- /dev/null +++ b/python/tvm/relay/df_pattern/__init__.py @@ -0,0 +1,35 @@ +from ....base import Node +from . import _ffi as ffi + +def register_df_node(type_key=None): + """Register a Relay node type. + + Parameters + ---------- + type_key : str or cls + The type key of the node. + """ + if not isinstance(type_key, str): + return tvm._ffi.register_object( + "relay.df_pattern" + type_key.__name__)(type_key) + return tvm._ffi.register_object(type_key) + +class DFPattern(Node): + """Base class of all primitive expressions. + + PrimExpr is used in the low-level code + optimizations and integer analysis. + """ + pass + +@register_df_node +class ExprPattern(DFPattern): + """A pattern which matches a constant expression. + + Parameters + ---------- + expr : tvm.relay.Expr + The expression to match. + """ + def __init__(self, expr): + self.__init_handle_by_constructor__(ffi.ExprPattern, expr) diff --git a/python/tvm/relay/df_pattern/_ffi.py b/python/tvm/relay/df_pattern/_ffi.py new file mode 100644 index 000000000000..2049f4217efb --- /dev/null +++ b/python/tvm/relay/df_pattern/_ffi.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DataFlow Pattern Language FFI bindings.""" +import tvm._ffi + +tvm._ffi._init_api("relay.df_pattern", __name__) diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc new file mode 100644 index 000000000000..e7f0d8dd9746 --- /dev/null +++ b/src/relay/ir/dataflow_pattern.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_pattern.cc + * \brief The dataflow pattern language for Relay. + */ +#include + +namespace tvm { +namespace relay { + +ExprPattern::ExprPattern(Expr expr) { + ObjectPtr n = make_object(); + n->expr = std::move(expr); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ExprPatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.ExprPattern") +.set_body_typed([](Expr e) { + return ExprPattern(e); + }); + +// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { +// auto* node = static_cast(ref.get()); +// const PackedFunc* fprint = Registry::Get("relay._constant_repr"); +// CHECK(fprint) << "unable to find printing function for constants"; +// std::string data = (*fprint)(GetRef(node)); +// p->stream << "Constant(" << data << ")"; +// }); + +} // namespace relay +} // namespace tvm From f570bf0f57840434d6ef26f3584aa393eadfc001 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 9 Mar 2020 13:57:47 -0700 Subject: [PATCH 02/46] Add pattern node test case --- include/tvm/relay/dataflow_pattern.h | 2 +- python/tvm/relay/df_pattern/__init__.py | 9 +++++---- tests/python/relay/test_df_pattern.py | 10 ++++++++++ 3 files changed, 16 insertions(+), 5 deletions(-) create mode 100644 tests/python/relay/test_df_pattern.py diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index f973298f7665..adb02f3c50aa 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -68,7 +68,7 @@ class ExprPatternNode : public DFPatternNode { v->Visit("expr", &expr); } - static constexpr const char* _type_key = "relay.pattern.Expr"; + static constexpr const char* _type_key = "relay.df_pattern.ExprPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); }; diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 6d241c690a9f..2ec88cfbce41 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -1,4 +1,5 @@ -from ....base import Node +from ...ir.base import Node +from ... import _ffi as tvm_ffi from . import _ffi as ffi def register_df_node(type_key=None): @@ -10,9 +11,9 @@ def register_df_node(type_key=None): The type key of the node. """ if not isinstance(type_key, str): - return tvm._ffi.register_object( - "relay.df_pattern" + type_key.__name__)(type_key) - return tvm._ffi.register_object(type_key) + return tvm_ffi.register_object( + "relay.df_pattern." + type_key.__name__)(type_key) + return tvm_ffi.register_object(type_key) class DFPattern(Node): """Base class of all primitive expressions. diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py new file mode 100644 index 000000000000..9b11fe748bad --- /dev/null +++ b/tests/python/relay/test_df_pattern.py @@ -0,0 +1,10 @@ +import tvm +from tvm import relay +from tvm.relay.df_pattern import ExprPattern + + +def test_expr_pattern(): + ep = ExprPattern(relay.var('x', shape=(4, 1))) + print(ep) + +test_expr_pattern() From 9827c32c65a42ca70b6a28589f209ee62badd83f Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 11 Mar 2020 13:25:42 -0700 Subject: [PATCH 03/46] Pattern Language --- include/tvm/relay/dataflow_pattern.h | 273 +++++++++++++++++++++++- python/tvm/relay/df_pattern/__init__.py | 178 ++++++++++++++- src/relay/ir/dataflow_pattern.cc | 161 +++++++++++++- tests/python/relay/test_df_pattern.py | 57 ++++- 4 files changed, 650 insertions(+), 19 deletions(-) diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index adb02f3c50aa..93bec950bede 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -50,14 +50,7 @@ class DFPattern : public ObjectRef { }; /*! - * \brief A pattern which matches a literal expression. - * - * \note Uses structural equality on expressions to check equality. - * - */ -class ExprPattern; -/*! - * \brief Constant tensor type. + * \brief Pattern for Relay Expression. */ class ExprPatternNode : public DFPatternNode { public: @@ -72,12 +65,276 @@ class ExprPatternNode : public DFPatternNode { TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); }; +/*! + * \brief A pattern which matches a literal expression. + * + * \note Uses structural equality on expressions to check equality. + * + */ class ExprPattern : public DFPattern { public: TVM_DLL ExprPattern(Expr expr); TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); }; + +/*! + * \brief A Pattern to Match a Relay Variable + */ +class VarPattern; +/*! \brief Container for Var */ +class VarPatternNode : public DFPatternNode { + public: + /*! + * \brief The name of the Var (optional). + */ + std::string name; + /*! + * \brief type annotaion of the variable. + * This field records user provided type annotation of the Var. + * This field is optional and can be None. + */ + Type type_annotation; + + /*! \return The name hint of the variable */ + const std::string& name_hint() const { + return name; + } + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("type_annotation", &type_annotation); + } + + TVM_DLL static VarPattern make(std::string name_hint, Type type_annotation); + + static constexpr const char* _type_key = "relay.df_pattern.VarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode); +}; + +class VarPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); +}; + +/*! + * \brief Call corresponds to operator invocation. + * Corresponds to the operator in computational graph terminology. + */ +class CallPattern; +/*! \brief CallPattern container. */ +class CallPatternNode : public DFPatternNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be relay::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, Var). + */ + DFPattern op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The type arguments passed to polymorphic(template) function. + * + * This is the advance feature that is only used when the function is + * polymorphic. It is safe to be ignored in most cases. For example, in the + * following code, the type_args of addone call is [int]. + * + * \code + * + * template + * T addone(T a) { return a + 1; } + * + * void main() { + * int x = addone(10); + * } + * + * \endcode + */ + tvm::Array type_args; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("attrs", &attrs); + v->Visit("type_args", &type_args); + } + + TVM_DLL static CallPattern make(DFPattern op, Array args, Attrs attrs, + Array type_args); + + static constexpr const char* _type_key = "relay.df_pattern.CallPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); +}; + +class CallPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); +}; + +/*! \brief Tuple of multiple Exprs */ +class TuplePattern; +/*! \brief Tuple container */ +class TuplePatternNode : public DFPatternNode { + public: + /*! \brief the fields of the tuple */ + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("fields", &fields); + } + + TVM_DLL static TuplePattern make(tvm::Array fields); + + static constexpr const char* _type_key = "relay.df_pattern.TuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); +}; + +class TuplePattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); +}; + +/*! \brief Get index-th field out of a tuple. */ +class TupleGetItemPattern; +class TupleGetItemPatternNode : public DFPatternNode { + public: + /*! \brief The tuple Expression */ + DFPattern tuple; + /*! \brief which value to get */ + int index; + + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple_value", &tuple); + } + + TVM_DLL static TupleGetItemPattern make(DFPattern tuple, int index); + + static constexpr const char* _type_key = "relay.df_pattern.TupleGetItemPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); +}; + +class TupleGetItemPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); +}; + +class AltPattern; +/*! + * \brief Pattern for Alternate Expressions. + */ +class AltPatternNode : public DFPatternNode { + public: + /*! \brief The left optional pattern. */ + DFPattern left; + /*! \brief The right optional pattern. */ + DFPattern right; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + TVM_DLL static AltPattern make(DFPattern left, DFPattern right); + + static constexpr const char* _type_key = "relay.df_pattern.AltPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches either of two patterns + */ +class AltPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode); +}; + + +/*! + * \brief Wildcard Pattern. + */ +class WildcardPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.df_pattern.WildcardPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches anything. + */ +class WildcardPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); +}; + +class TypePattern; +/*! + * \brief Pattern for Types. + */ +class TypePatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The type to match */ + Type type; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("type", &type); + } + + TVM_DLL static TypePattern make(DFPattern pattern, Type type); + + static constexpr const char* _type_key = "relay.df_pattern.TypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a type in another pattern + */ +class TypePattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); +}; + +class AttrPattern; +/*! + * \brief Pattern for Types. + */ +class AttrPatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The attribute to match */ + Attrs attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("attrs", &attrs); + } + + TVM_DLL static AttrPattern make(DFPattern pattern, Attrs attrs); + + static constexpr const char* _type_key = "relay.df_pattern.AttrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a type in another pattern + */ +class AttrPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_DATAFLOW_PATTERN_H_ diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 2ec88cfbce41..05249f971d7c 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -1,5 +1,7 @@ from ...ir.base import Node +from ...ir import make_node from ... import _ffi as tvm_ffi +from ..op import get from . import _ffi as ffi def register_df_node(type_key=None): @@ -21,7 +23,15 @@ class DFPattern(Node): PrimExpr is used in the low-level code optimizations and integer analysis. """ - pass + def __call__(self, *args): + return CallPattern(self, list(args)) + + def __or__(self, other): + return AltPattern(self, other) + + def has_attr(self, attr_name, attr_value): + attrs = make_node("DictAttrs", **{attr_name:attr_value}) + return AttrPattern(self, attrs) @register_df_node class ExprPattern(DFPattern): @@ -34,3 +44,169 @@ class ExprPattern(DFPattern): """ def __init__(self, expr): self.__init_handle_by_constructor__(ffi.ExprPattern, expr) + +@register_df_node +class VarPattern(DFPattern): + """A local variable in Relay. + + Local variable can be used to declare input + arguments to a function, or intermediate variables. + + Parameters + ---------- + name_hint: str + The name of the variable. + This name only acts as a hint, and is not used + for equality. + + type_annotation: tvm.relay.Type, optional + The type annotation on the variable. + """ + def __init__(self, name_hint, type_annotation=None): + self.__init_handle_by_constructor__( + ffi.VarPattern, name_hint, type_annotation) + +# @property +# def name_hint(self): +# """Get name hint of the current var.""" +# name = self.name +# return name + + +@register_df_node +class CallPattern(DFPattern): + """A pattern matching a function call node in Relay. + + Parameters + ---------- + op: realy.df_pattern.DFPattern + The operation to be called. + + args: List[realy.df_pattern.DFPattern] + The arguments to the call. + + attrs: Optional[tvm.Attrs] + Attributes to the call, can be None + + type_args: Optional[List[tvm.relay.Type]] + The additional type arguments, this is only + used in advanced usecase of template functions. + """ + def __init__(self, op, args, attrs=None, type_args=None): + if not type_args: + type_args = [] + self.__init_handle_by_constructor__( + ffi.CallPattern, op, args, attrs, type_args) + +@register_df_node +class TuplePattern(DFPattern): + """A patern matching a Relay Tuple. + + Parameters + ---------- + fields : List[tvm.relay.df_pattern.DFPattern] + The fields in the tuple. + """ + def __init__(self, fields): + self.__init_handle_by_constructor__(ffi.TuplePattern, fields) + + def __getitem__(self, index): + if index >= len(self): + raise IndexError("TuplePattern index out of range") + return self.fields[index] + + def __len__(self): + return len(self.fields) + + def astype(self, _): + raise TypeError("astype cannot be used on TuplePattern") + +@register_df_node +class TupleGetItemPattern(DFPattern): + """Get index-th item from a TuplePattern. + + Parameters + ---------- + tuple_value: tvm.relay.df_pattern.DFPattern + The input tuple expression. + + index: int + The index. + """ + def __init__(self, tuple_value, index): + self.__init_handle_by_constructor__( + ffi.TupleGetItemPattern, tuple_value, index) + +@register_df_node +class AltPattern(DFPattern): + """Create a Pattern that can match one of two conditions + + Parameters + ---------- + left: tvm.relay.df_pattern.DFPattern + One possible matching Pattern + right: tvm.relay.df_pattern.DFPattern + One possible matching Pattern + """ + def __init__(self, tuple_value, index): + self.__init_handle_by_constructor__( + ffi.AltPattern, tuple_value, index) + +@register_df_node +class WildcardPattern(DFPattern): + """A pattern which matches anything. + """ + def __init__(self): + self.__init_handle_by_constructor__(ffi.WildcardPattern) + +@register_df_node +class TypePattern(DFPattern): + """Get index-th item from a TuplePattern. + + Parameters + ---------- + pattern: tvm.relay.df_pattern.DFPattern + The input tuple expression. + + ttype: tvm.relay.Type + The type to match + """ + def __init__(self, pattern, ttype): + self.__init_handle_by_constructor__( + ffi.TypePattern, pattern, ttype) + +@register_df_node +class AttrPattern(DFPattern): + """Get index-th item from a TuplePattern. + + Parameters + ---------- + pattern: tvm.relay.df_pattern.DFPattern + The input tuple expression. + + attrs: tvm.Attrs + The attributes to match + """ + def __init__(self, pattern, attrs): + self.__init_handle_by_constructor__( + ffi.AttrPattern, pattern, attrs) + +def is_input(name=None) -> DFPattern: + return VarPattern(name) + +def is_op(op_name: str) -> DFPattern: + op = get(op_name) + return ExprPattern(op) + +def wildcard() -> DFPattern: + return WildcardPattern() + +def has_type(ty, pattern=None): + if pattern is None: + pattern = wildcard() + return TypePattern(pattern, ty) + +def has_attr(attr_name, attr_value, pattern=None): + if pattern is None: + pattern = wildcard() + return patter.has_attr(attr_name, attr_value) diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index e7f0d8dd9746..5a34ae9d832d 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -39,14 +39,159 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.ExprPattern") return ExprPattern(e); }); -// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { -// auto* node = static_cast(ref.get()); -// const PackedFunc* fprint = Registry::Get("relay._constant_repr"); -// CHECK(fprint) << "unable to find printing function for constants"; -// std::string data = (*fprint)(GetRef(node)); -// p->stream << "Constant(" << data << ")"; -// }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->Print(node->expr); + }); + + +VarPattern VarPatternNode::make(std::string name_hint, Type type_annotation) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + n->type_annotation = std::move(type_annotation); + return VarPattern(n); +} + +TVM_REGISTER_NODE_TYPE(VarPatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.VarPattern") +.set_body_typed(static_cast(VarPatternNode::make)); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "VarPattern(" << node->name_hint(); + if (node->type_annotation.defined()) { + p->stream << ", ty="; + p->Print(node->type_annotation); + } + p->stream << ")"; + }); + +CallPattern CallPatternNode::make(DFPattern op, Array args, Attrs attrs, + Array type_args) { + ObjectPtr n = make_object(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->type_args = std::move(type_args); + return CallPattern(n); +} + +TVM_REGISTER_NODE_TYPE(CallNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.CallPattern") +.set_body_typed(CallPatternNode::make); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", " << node->attrs + << ", " << node->type_args << ")"; +}); + +TuplePattern TuplePatternNode::make(tvm::Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + return TuplePattern(n); +} + +TVM_REGISTER_NODE_TYPE(TuplePatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.TuplePattern") +.set_body_typed(TuplePatternNode::make); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TuplePattern(" << node->fields << ")"; + }); + +TupleGetItemPattern TupleGetItemPatternNode::make(DFPattern tuple, int index) { + ObjectPtr n = make_object(); + n->tuple = std::move(tuple); + n->index = index; + return TupleGetItemPattern(n); +} + +TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.TupleGetItemPattern") +.set_body_typed(TupleGetItemPatternNode::make); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << node->index << ")"; +}); + +AltPattern AltPatternNode::make(DFPattern left, DFPattern right) { + ObjectPtr n = make_object(); + n->left = std::move(left); + n->right = std::move(right); + return AltPattern(n); +} + +TVM_REGISTER_NODE_TYPE(AltPatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.AltPattern") +.set_body_typed(AltPatternNode::make); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "AltPattern(" << node->left << " | " << node->right << ")"; + }); + +TVM_REGISTER_NODE_TYPE(WildcardPatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.WildcardPattern") +.set_body_typed([]() { + auto w = WildcardPattern(make_object()); + return w; + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "*"; + }); + +TypePattern TypePatternNode::make(DFPattern pattern, Type type) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->type = std::move(type); + return TypePattern(n); +} + +TVM_REGISTER_NODE_TYPE(TypePatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.TypePattern") +.set_body_typed(TypePatternNode::make); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; + }); + +AttrPattern AttrPatternNode::make(DFPattern pattern, Attrs attrs) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->attrs = std::move(attrs); + return AttrPattern(n); +} + +TVM_REGISTER_NODE_TYPE(AttrPatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.AttrPattern") +.set_body_typed(AttrPatternNode::make); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; + }); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 9b11fe748bad..1344da25a1bc 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -1,10 +1,63 @@ import tvm from tvm import relay -from tvm.relay.df_pattern import ExprPattern +from tvm.relay.df_pattern import * def test_expr_pattern(): ep = ExprPattern(relay.var('x', shape=(4, 1))) print(ep) -test_expr_pattern() +def test_var_pattern(): + v = is_input("x") + print(v) + +def test_wildcard_pattern(): + wc = wildcard() + print(wc) + +def test_CallPattern(): + wc1 = wildcard() + wc2 = wildcard() + c = is_op("add")(wc1, wc2) + print(c) + +def test_TuplePattern(): + wc1 = wildcard() + wc2 = wildcard() + t = TuplePattern([wc1, wc2]) + print(t) + +def test_TupleGetItemPattern(): + wc1 = wildcard() + wc2 = wildcard() + t = TuplePattern([wc1, wc2]) + tgi = TupleGetItemPattern(t, 1) + print(tgi) + +def test_AltPattern(): + is_add_or_sub = is_op('add') | is_op('subtract') + print(is_add_or_sub) + +def test_TypePattern(): + ty_pat = has_type(relay.TensorType((10, 10), "float32")) + print(ty_pat) + +# NB: 1 corresponds to the C++ enum that specicfies this +# we loose the type safety due to the Python/C++ calling +# convention. +K_ELEMWISE = 1 +def test_AttrPattern(): + op = is_op('add').has_attr("TOpPattern", K_ELEMWISE) + op_pat = op(wildcard(), wildcard()) + print(op_pat) + +if __name__ == "__main__": + test_expr_pattern() + test_var_pattern() + test_wildcard_pattern() + test_CallPattern() + test_TuplePattern() + test_TupleGetItemPattern() + test_AltPattern() + test_TypePattern() + test_AttrPattern() From ccd2ee9daf6cee2f759e9573c9b3773cb28a98b4 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 25 Mar 2020 15:30:16 -0700 Subject: [PATCH 04/46] Pattern Matcher --- include/tvm/relay/dataflow_matcher.h | 139 ++++++++++++ include/tvm/relay/dataflow_pattern.h | 1 + python/tvm/relay/df_pattern/__init__.py | 29 ++- src/relay/ir/dataflow_matcher.cc | 290 ++++++++++++++++++++++++ tests/python/relay/test_df_pattern.py | 207 +++++++++++++++-- 5 files changed, 650 insertions(+), 16 deletions(-) create mode 100644 include/tvm/relay/dataflow_matcher.h create mode 100644 src/relay/ir/dataflow_matcher.cc diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h new file mode 100644 index 000000000000..cef38f71cf9f --- /dev/null +++ b/include/tvm/relay/dataflow_matcher.h @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/dataflow_matcher.h + * \brief A pattern matcher for matching dataflow properties. + */ +#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_ +#define TVM_RELAY_DATAFLOW_MATCHER_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor that dispatches on in the first DFPattern argument. + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const DFPattern&, + * Args...) + */ +template +class DFPatternFunctor; + +// functions to be overriden. +#define DFPATTERN_FUNCTOR_DEFAULT \ + { return VisitDFPatternDefault_(op, std::forward(args)...); } + +#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitDFPattern_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class DFPatternFunctor { + private: + using TSelf = DFPatternFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~DFPatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const DFPattern& n, Args... args) { + return VisitDFPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitDFPattern(const DFPattern& n, Args... args) { + CHECK(n.defined()); + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPatternDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); + return vtable; + } +}; + +class DFPatternMutator : public DFPatternFunctor { + public: + virtual DFPattern Mutate(const DFPattern& pattern); + DFPattern VisitDFPattern(const DFPattern& pattern) override; + DFPattern VisitDFPattern_(const AltPatternNode* op) override; + DFPattern VisitDFPattern_(const AttrPatternNode* op) override; + DFPattern VisitDFPattern_(const CallPatternNode* op) override; + DFPattern VisitDFPattern_(const ExprPatternNode* op) override; + DFPattern VisitDFPattern_(const TupleGetItemPatternNode* op) override; + DFPattern VisitDFPattern_(const TuplePatternNode* op) override; + DFPattern VisitDFPattern_(const TypePatternNode* op) override; + DFPattern VisitDFPattern_(const VarPatternNode* op) override; + DFPattern VisitDFPattern_(const WildcardPatternNode* op) override; + + protected: + std::unordered_map memo_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_DATAFLOW_MATCHER_H_ diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 93bec950bede..beac285e3a4d 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -26,6 +26,7 @@ #include #include +#include namespace tvm { namespace relay { diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 05249f971d7c..14d89d08b07c 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -1,9 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relay Pattern Language and tooling.""" from ...ir.base import Node from ...ir import make_node from ... import _ffi as tvm_ffi from ..op import get from . import _ffi as ffi +def match(pattern, expr): + return ffi.match(pattern, expr) + def register_df_node(type_key=None): """Register a Relay node type. @@ -33,6 +53,9 @@ def has_attr(self, attr_name, attr_value): attrs = make_node("DictAttrs", **{attr_name:attr_value}) return AttrPattern(self, attrs) + def match(self, expr): + return match(self, expr) + @register_df_node class ExprPattern(DFPattern): """A pattern which matches a constant expression. @@ -191,7 +214,7 @@ def __init__(self, pattern, attrs): self.__init_handle_by_constructor__( ffi.AttrPattern, pattern, attrs) -def is_input(name=None) -> DFPattern: +def is_input(name="") -> DFPattern: return VarPattern(name) def is_op(op_name: str) -> DFPattern: @@ -201,10 +224,10 @@ def is_op(op_name: str) -> DFPattern: def wildcard() -> DFPattern: return WildcardPattern() -def has_type(ty, pattern=None): +def has_type(ttype, pattern=None): if pattern is None: pattern = wildcard() - return TypePattern(pattern, ty) + return TypePattern(pattern, ttype) def has_attr(attr_name, attr_value, pattern=None): if pattern is None: diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc new file mode 100644 index 000000000000..fc8f1a585500 --- /dev/null +++ b/src/relay/ir/dataflow_matcher.cc @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relay. + */ + +#include +#include +#include + +namespace tvm { +namespace relay { + +class DFPatternMatcher : public DFPatternFunctor { + public: + bool Match(const DFPattern& pattern, const Expr& expr); + bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; + + protected: + std::unordered_map memo_; +}; + +bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + if (memo_.count(pattern)) { + return expr.same_as(memo_[pattern]); + } else { + auto out = VisitDFPattern(pattern, expr); + if (out) { + memo_[pattern] = expr; + } + return out; + } +} + +bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { + return Match(op->left, expr) || Match(op->right, expr); +} +bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) { + bool matches = false; + if (const auto* op_node = expr.as()) { + Op op = GetRef(op_node); + auto attributes = attr_pattern->attrs.as()->dict; + for (auto kv : attributes) { + auto attr_name = kv.first; + auto attr_value = kv.second; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + switch (op_map[op].type_code()) { + case kDLInt: + if (auto* val = kv.second.as()) { + matches = val->value == op_map[op].operator int64_t(); + } + break; + case kDLFloat: + if (auto* val = kv.second.as()) { + matches = val->value == op_map[op].operator double(); + } + break; + case kTVMStr: + if (auto* val = kv.second.as()) { + matches = val->value == op_map[op].operator std::string(); + } + break; + default: + throw "Unsupported type"; + } + } + } + } + return matches; +} +bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* call_node = expr.as()) { + if (op->args.size() == call_node->args.size()) { + matches = Match(op->op, call_node->op); + size_t i = 0; + while (matches && i < op->args.size()) { + matches &= Match(op->args[i], call_node->args[i]); + ++i; + } + } + } + return matches; +} +bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { + return op->expr == expr; +} +bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* tuple_get_item_node = expr.as()) { + matches = + (op->index == tuple_get_item_node->index) && Match(op->tuple, tuple_get_item_node->tuple); + } + return matches; +} +bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* tuple_node = expr.as()) { + if (op->fields.size() == tuple_node->fields.size()) { + matches = true; + size_t i = 0; + while (matches && i < op->fields.size()) { + matches &= Match(op->fields[i], tuple_node->fields[i]); + ++i; + } + } + } + return matches; +} +Expr InferType(const Expr& expr) { + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + if (expr.as()) { + return mod->Lookup("main"); + } else { + return mod->Lookup("main").as()->body; + } +} +bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { + auto expr_type = InferType(expr).as()->checked_type(); + return (StructuralEqual()(op->type, expr_type)) && Match(op->pattern, expr); +} +bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* var_node = expr.as()) { + matches = true; + if (op->name_hint() != "") { + matches &= op->name_hint() == var_node->name_hint(); + } + } + return matches; +} +bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) { + return true; +} + +// DFPatternMutator + +DFPattern DFPatternMutator::Mutate(const DFPattern& pattern) { return VisitDFPattern(pattern); } + +DFPattern DFPatternMutator::VisitDFPattern(const DFPattern& pattern) { + auto it = this->memo_.find(pattern); + if (it != this->memo_.end()) { + return it->second; + } else { + auto new_pattern = DFPatternFunctor::VisitDFPattern(pattern); + memo_[pattern] = new_pattern; + return new_pattern; + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const AltPatternNode* op) { + auto new_left = Mutate(op->left); + auto new_right = Mutate(op->right); + + if (new_left.same_as(op->left) && new_right.same_as(op->right)) { + return GetRef(op); + } else { + return AltPatternNode::make(new_left, new_right); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const AttrPatternNode* op) { + auto new_pattern = Mutate(op->pattern); + if (new_pattern.same_as(op->pattern)) { + return GetRef(op); + } else { + return AttrPatternNode::make(new_pattern, op->attrs); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const CallPatternNode* op) { + auto new_op = Mutate(op->op); + bool unchanged = op->op.same_as(new_op); + tvm::Array call_args; + for (auto arg : op->args) { + auto new_arg = Mutate(arg); + call_args.push_back(new_arg); + unchanged &= arg.same_as(new_arg); + } + if (unchanged) { + return GetRef(op); + } else { + return CallPatternNode::make(new_op, call_args, op->attrs, op->type_args); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const ExprPatternNode* op) { + return GetRef(op); +} + +DFPattern DFPatternMutator::VisitDFPattern_(const TupleGetItemPatternNode* op) { + auto new_tuple = Mutate(op->tuple); + if (new_tuple.same_as(op->tuple)) { + return GetRef(op); + } else { + return TupleGetItemPatternNode::make(op->tuple, op->index); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const TuplePatternNode* op) { + bool unchanged = true; + tvm::Array fields; + for (auto field : op->fields) { + auto new_field = Mutate(field); + fields.push_back(new_field); + unchanged &= field.same_as(new_field); + } + if (unchanged) { + return GetRef(op); + } else { + return TuplePatternNode::make(fields); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const TypePatternNode* op) { + auto new_pattern = Mutate(op->pattern); + if (new_pattern.same_as(op->pattern)) { + return GetRef(op); + } else { + return TypePatternNode::make(new_pattern, op->type); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const VarPatternNode* op) { + return GetRef(op); +} + +DFPattern DFPatternMutator::VisitDFPattern_(const WildcardPatternNode* op) { + return GetRef(op); +} + +// Prepare + +class DFPatternPrepare : protected DFPatternMutator { + public: + DFPattern Prepare(const DFPattern& pattern) { return Mutate(pattern); } + DFPattern VisitDFPattern_(const CallPatternNode* op) { + auto post = DFPatternMutator::VisitDFPattern_(op); + auto* post_node = post.as(); + if (auto* expr_pattern = post_node->op.as()) { + if (auto* op_node = expr_pattern->expr.as()) { + if ((op_node->name == "add") || (op_node->name == "multiply")) { + tvm::Array call_args; + for (auto it = post_node->args.rbegin(); it != post_node->args.rend(); ++it) { + call_args.push_back(*it); + } + return AltPatternNode::make( + post, CallPatternNode::make(post_node->op, call_args, post_node->attrs, + post_node->type_args)); + } + } + } + return post; + } +}; + +TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) { + return DFPatternMatcher().Match(DFPatternPrepare().Prepare(pattern), expr); +}); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 1344da25a1bc..3a4d6b248a52 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -1,8 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. import tvm from tvm import relay from tvm.relay.df_pattern import * +# NB: 1 corresponds to the C++ enum that specicfies this +# we loose the type safety due to the Python/C++ calling +# convention. +K_ELEMWISE = 1 +## NODE TESTS def test_expr_pattern(): ep = ExprPattern(relay.var('x', shape=(4, 1))) print(ep) @@ -42,22 +63,182 @@ def test_TypePattern(): ty_pat = has_type(relay.TensorType((10, 10), "float32")) print(ty_pat) -# NB: 1 corresponds to the C++ enum that specicfies this -# we loose the type safety due to the Python/C++ calling -# convention. -K_ELEMWISE = 1 def test_AttrPattern(): op = is_op('add').has_attr("TOpPattern", K_ELEMWISE) op_pat = op(wildcard(), wildcard()) print(op_pat) +## MATCHER TESTS + +def test_match_op(): + assert is_op('add').match(relay.op.op.get("add")) + +def test_no_match_op(): + assert not is_op('add').match(relay.op.op.get("subtract")) + +def test_match_op_or(): + is_add_or_sub = is_op('add') | is_op('subtract') + assert is_add_or_sub.match(relay.op.op.get("add")) + assert is_add_or_sub.match(relay.op.op.get("subtract")) + +def test_match_call_commutive(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('add')(is_input("x"), is_input("y")) + assert add_pattern.match(x + y) + assert add_pattern.match(y + x) + mul_pattern = is_op('multiply')(is_input("x"), is_input("y")) + assert mul_pattern.match(x * y) + assert mul_pattern.match(y * x) + +def test_no_match_call_commutive(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('subtract')(is_input("x"), is_input("y")) + assert add_pattern.match(x - y) + assert not add_pattern.match(y - x) + add_pattern = is_op('divide')(is_input("x"), is_input("y")) + assert add_pattern.match(x / y) + assert not add_pattern.match(y / x) + +def test_match_call(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('add')(wildcard(), wildcard()) + assert add_pattern.match(x + y) + +def test_no_match_call(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('add')(wildcard(), wildcard()) + assert not add_pattern.match(x - y) + +def test_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) + assert tuple_pattern.match(relay.expr.Tuple((x,y,z))) + +def test_no_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard())) + assert not tuple_pattern.match(relay.expr.Tuple((x,y,z))) + +def test_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) + tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1) + assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 1)) + +def test_no_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"))) + tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1) + assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 2)) + +def test_match_type(): + x = relay.var('x', shape=(10, 10), dtype="float32") + ty_pat = has_type(relay.TensorType((10, 10), "float32")) + assert ty_pat.match(x) + +def test_no_match_type(): + x = relay.var('x', shape=(10, 10), dtype="int32") + ty_pat = has_type(relay.TensorType((10, 10), "float32")) + assert not ty_pat.match(x) + +def test_match_attr(): + op = is_op('add').has_attr("TOpPattern", K_ELEMWISE) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert op_pat.match(x + y) + +def test_no_match_attr(): + op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert not op_pat.match(relay.op.nn.dense(x, y)) + +def test_match_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + +def test_no_match_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(leaky_relu) + assert not diamond.match(relu) + +def test_match_fake_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + input1 = relay.var('input1') + weight1 = relay.var('weight1') + conv2d1 = relay.op.nn.conv2d(input1, weight1) + inp2 = relay.var('input2') + weight2 = relay.var('weight2') + conv2d2 = relay.op.nn.conv2d(inp2, weight2) + relu = relay.op.nn.relu(conv2d1) + leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(out) + if __name__ == "__main__": - test_expr_pattern() - test_var_pattern() - test_wildcard_pattern() - test_CallPattern() - test_TuplePattern() - test_TupleGetItemPattern() - test_AltPattern() - test_TypePattern() - test_AttrPattern() + test_match_op() + test_no_match_op() + test_match_op_or() + test_match_call() + test_no_match_call() + test_match_call_commutive() + test_no_match_call_commutive() + test_match_tuple() + test_no_match_tuple() + test_match_type() + test_no_match_type() + test_match_attr() + test_no_match_attr() + test_match_diamond() + test_no_match_diamond() + test_match_fake_diamond() From 8755b8b0c78d77597b63ab4cfe15830729d22be4 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 3 Apr 2020 12:13:11 -0700 Subject: [PATCH 05/46] Pattern Rewriter --- include/tvm/relay/dataflow_matcher.h | 27 ++++++++++ python/tvm/relay/df_pattern/__init__.py | 15 ++++-- src/relay/ir/dataflow_matcher.cc | 71 +++++++++++++++++++++---- tests/python/relay/test_df_pattern.py | 12 +++++ 4 files changed, 113 insertions(+), 12 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index cef38f71cf9f..4eea4befb1b5 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -133,6 +133,33 @@ class DFPatternMutator : public DFPatternFunctor { std::unordered_map memo_; }; +class DFPatternCallback; +/*! + * \brief Base type of all dataflow pattern callbacks. + * \sa DFPatternCallback + */ +class DFPatternCallbackNode : public Object { + public: + DFPattern pattern_; + PackedFunc function_; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + TVM_DLL static DFPatternCallback make(DFPattern pattern, PackedFunc callback); + + static constexpr const char* _type_key = "DFPatternCallbackNode"; + TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object); +}; + +/*! + * \brief Managed reference to dataflow pattern callbacks. + * \sa DFPatternCallbackNode + */ +class DFPatternCallback : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); +}; + } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 14d89d08b07c..6453dadf2677 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -17,13 +17,11 @@ """The Relay Pattern Language and tooling.""" from ...ir.base import Node from ...ir import make_node +from ...runtime import Object from ... import _ffi as tvm_ffi from ..op import get from . import _ffi as ffi -def match(pattern, expr): - return ffi.match(pattern, expr) - def register_df_node(type_key=None): """Register a Relay node type. @@ -214,6 +212,11 @@ def __init__(self, pattern, attrs): self.__init_handle_by_constructor__( ffi.AttrPattern, pattern, attrs) +class DFPatternCallback(Object): + def __init__(self, pattern, callback): + self.__init_handle_by_constructor__( + ffi.DFPatternCallback, pattern, callback) + def is_input(name="") -> DFPattern: return VarPattern(name) @@ -233,3 +236,9 @@ def has_attr(attr_name, attr_value, pattern=None): if pattern is None: pattern = wildcard() return patter.has_attr(attr_name, attr_value) + +def match(pattern, expr): + return ffi.match(pattern, expr) + +def rewrite(callbacks, expr): + return ffi.rewrite(callbacks, expr) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index fc8f1a585500..b218ffeda15f 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -23,15 +23,21 @@ */ #include +#include #include #include namespace tvm { namespace relay { +// Pattern Matcher + class DFPatternMatcher : public DFPatternFunctor { public: bool Match(const DFPattern& pattern, const Expr& expr); + + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; @@ -42,15 +48,19 @@ class DFPatternMatcher : public DFPatternFunctor memo_; }; bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + memo_.clear(); + return VisitDFPattern(pattern, expr); +} + +bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) { if (memo_.count(pattern)) { return expr.same_as(memo_[pattern]); } else { - auto out = VisitDFPattern(pattern, expr); + auto out = DFPatternFunctor::VisitDFPattern(pattern, expr); if (out) { memo_[pattern] = expr; } @@ -59,7 +69,7 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { } bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { - return Match(op->left, expr) || Match(op->right, expr); + return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) { bool matches = false; @@ -99,10 +109,10 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex bool matches = false; if (const auto* call_node = expr.as()) { if (op->args.size() == call_node->args.size()) { - matches = Match(op->op, call_node->op); + matches = VisitDFPattern(op->op, call_node->op); size_t i = 0; while (matches && i < op->args.size()) { - matches &= Match(op->args[i], call_node->args[i]); + matches &= VisitDFPattern(op->args[i], call_node->args[i]); ++i; } } @@ -115,8 +125,8 @@ bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& ex bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { bool matches = false; if (const auto* tuple_get_item_node = expr.as()) { - matches = - (op->index == tuple_get_item_node->index) && Match(op->tuple, tuple_get_item_node->tuple); + matches = (op->index == tuple_get_item_node->index) && + VisitDFPattern(op->tuple, tuple_get_item_node->tuple); } return matches; } @@ -127,7 +137,7 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e matches = true; size_t i = 0; while (matches && i < op->fields.size()) { - matches &= Match(op->fields[i], tuple_node->fields[i]); + matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); ++i; } } @@ -145,7 +155,7 @@ Expr InferType(const Expr& expr) { } bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { auto expr_type = InferType(expr).as()->checked_type(); - return (StructuralEqual()(op->type, expr_type)) && Match(op->pattern, expr); + return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); } bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { bool matches = false; @@ -286,5 +296,48 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern patter return DFPatternMatcher().Match(DFPatternPrepare().Prepare(pattern), expr); }); +// Rewrite + +DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) { + ObjectPtr n = make_object(); + n->pattern_ = std::move(pattern); + n->function_ = std::move(function); + return DFPatternCallback(n); +} + +TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") +.set_body_typed(DFPatternCallbackNode::make); + +class PatternRewriter : public ExprMutator { + public: + PatternRewriter(const Array& callbacks) : callbacks_(callbacks) {} + Expr Rewrite(const Expr& pre) { + return this->VisitExpr(pre); + } + + protected: + Expr VisitExpr(const Expr& pre) override { + auto post = ExprMutator::VisitExpr(pre); + Expr out = post; + for (auto& callback : callbacks_) { + if (auto* callback_node = callback.as()) { + if (matcher_.Match(callback_node->pattern_, out)) { + out = callback_node->function_(pre, out); + } + } + } + return out; + } + DFPatternMatcher matcher_; + Array callbacks_; +}; + +TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite") +.set_body_typed([](Array callbacks, Expr expr) { + return PatternRewriter(callbacks).Rewrite(expr); +}); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 3a4d6b248a52..3989fb570569 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -225,6 +225,17 @@ def test_match_fake_diamond(): # Check assert not diamond.match(out) +def test_match_rewrite(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('add')(wildcard(), wildcard()) + sub_pattern = is_op('subtract')(wildcard(), wildcard()) + def add_to_sub(pre, post): + return post.args[0] - post.args[1] + out = rewrite([DFPatternCallback(add_pattern, add_to_sub)], x + y) + assert sub_pattern.match(out) + + if __name__ == "__main__": test_match_op() test_no_match_op() @@ -242,3 +253,4 @@ def test_match_fake_diamond(): test_match_diamond() test_no_match_diamond() test_match_fake_diamond() + test_match_rewrite() From 475ef79c4d611299a9718dbd369d4a210b804721 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 6 Apr 2020 14:14:36 -0700 Subject: [PATCH 06/46] add batchnorm tests, commutivity extensions, realize I need to break the diamond for more complicated graphs --- python/tvm/relay/df_pattern/__init__.py | 14 ++ src/relay/ir/dataflow_matcher.cc | 213 +++++++++++++++++++----- src/relay/ir/expr_functor.cc | 7 +- tests/python/relay/test_df_pattern.py | 156 +++++++++++++++-- 4 files changed, 322 insertions(+), 68 deletions(-) diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 6453dadf2677..40480ed59311 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -47,6 +47,18 @@ def __call__(self, *args): def __or__(self, other): return AltPattern(self, other) + def __add__(self, other): + return is_op("add")(self, other) + + def __sub__(self, other): + return is_op("subtract")(self, other) + + def __mul__(self, other): + return is_op("multiply")(self, other) + + def __truediv__(self, other): + return is_op("divide")(self, other) + def has_attr(self, attr_name, attr_value): attrs = make_node("DictAttrs", **{attr_name:attr_value}) return AttrPattern(self, attrs) @@ -241,4 +253,6 @@ def match(pattern, expr): return ffi.match(pattern, expr) def rewrite(callbacks, expr): + if isinstance(callbacks, DFPatternCallback): + callbacks = [callbacks] return ffi.rewrite(callbacks, expr) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index b218ffeda15f..17a15e3f05ad 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -57,15 +57,15 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { } bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) { - if (memo_.count(pattern)) { - return expr.same_as(memo_[pattern]); - } else { - auto out = DFPatternFunctor::VisitDFPattern(pattern, expr); - if (out) { - memo_[pattern] = expr; - } - return out; - } +// if (memo_.count(pattern)) { +// return expr.same_as(memo_[pattern]); +// } else { + auto out = DFPatternFunctor::VisitDFPattern(pattern, expr); +// if (out) { +// memo_[pattern] = expr; +// } + return out; +// } } bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { @@ -105,19 +105,165 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons } return matches; } + +Array reverse(const Array args) { + Array new_args; + for (auto it = args.rbegin(); it != args.rend(); ++it) { + new_args.push_back(*it); + } + return new_args; +} + bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) { - bool matches = false; - if (const auto* call_node = expr.as()) { - if (op->args.size() == call_node->args.size()) { - matches = VisitDFPattern(op->op, call_node->op); - size_t i = 0; - while (matches && i < op->args.size()) { - matches &= VisitDFPattern(op->args[i], call_node->args[i]); + auto match_args = [this](const Array pattern_args, const Array expr_args) { + bool matches = true; + size_t i = 0; + if (pattern_args.size() == expr_args.size()) { + while (matches && i < pattern_args.size()) { + matches &= VisitDFPattern(pattern_args[i], expr_args[i]); ++i; } + } else { + matches = false; + } + return matches; + }; + + auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { + if (op) { + if (auto* expr_pattern = op->op.as()) { + return expr_pattern->expr.as(); + } + } + return nullptr; + }; + + if (const auto* call_node = expr.as()) { + auto matches_op = VisitDFPattern(op->op, call_node->op); + if (matches_op) { + if (match_args(op->args, call_node->args)) { + return true; + } + if (auto* op_node = get_op_node(op)) { + if ((op_node->name == "add")) { + if (match_args(reverse(op->args), call_node->args)) { + return true; + } + } else if ((op_node->name == "multiply")) { + if (match_args(reverse(op->args), call_node->args)) { + return true; + } + } + } + } else { + if (const OpNode* op_node = get_op_node(op)) { + if (op_node->name == "divide") { + if (auto* arg_node = op->args[0].as()) { + if (const OpNode* arg_op = get_op_node(arg_node)) { + if (arg_op->name == "multiply") { + auto associate_div_mul = [this, &op, &arg_node, &expr]() { + auto div1 = CallPatternNode::make(op->op, {arg_node->args[1], op->args[1]}, + op->attrs, op->type_args); + auto mul1 = CallPatternNode::make(arg_node->op, {arg_node->args[0], div1}, + arg_node->attrs, arg_node->type_args); + auto div2 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[1]}, + op->attrs, op->type_args); + auto mul2 = CallPatternNode::make(arg_node->op, {arg_node->args[1], div2}, + arg_node->attrs, arg_node->type_args); + return VisitDFPattern(mul1, expr)|| VisitDFPattern(mul2, expr); + }; + + if (const OpNode* expr_op_node = call_node->op.as()) { + if (expr_op_node->name == "multiply") { + if (auto* input_call_node = call_node->args[0].as()) { + if (const OpNode* input_op_node = input_call_node->op.as()) { + if (input_op_node->name == "divide") { + return associate_div_mul(); + } + } + } + if (auto* input_call_node = call_node->args[1].as()) { + if (const OpNode* input_op_node = input_call_node->op.as()) { + if (input_op_node->name == "divide") { + return associate_div_mul(); + } + } + } + } + } + } + } + } + } else if (op_node->name == "multiply") { + if (auto* arg_node = op->args[0].as()) { + if (const OpNode* arg_op = get_op_node(arg_node)) { + if (arg_op->name == "divide") { + auto associate_mul_div = [this, &op, &arg_node, &expr]() { + auto mul1 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[1]}, + op->attrs, op->type_args); + auto div1 = CallPatternNode::make(arg_node->op, {mul1, arg_node->args[1]}, + arg_node->attrs, arg_node->type_args); + return VisitDFPattern(div1, expr); + }; + + if (const OpNode* expr_op_node = call_node->op.as()) { + if (expr_op_node->name == "divide") { + if (auto* input_call_node = call_node->args[0].as()) { + if (const OpNode* input_op_node = input_call_node->op.as()) { + if (input_op_node->name == "multiply") { + return associate_mul_div(); + } + } + } + if (auto* input_call_node = call_node->args[1].as()) { + if (const OpNode* input_op_node = input_call_node->op.as()) { + if (input_op_node->name == "multiply") { + return associate_mul_div(); + } + } + } + } + } + } + } + } + if (auto* arg_node = op->args[1].as()) { + if (const OpNode* arg_op = get_op_node(arg_node)) { + if (arg_op->name == "divide") { + auto associate_mul_div = [this, &op, &arg_node, &expr]() { + auto mul1 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[0]}, + op->attrs, op->type_args); + auto div1 = CallPatternNode::make(arg_node->op, {mul1, arg_node->args[1]}, + arg_node->attrs, arg_node->type_args); + return VisitDFPattern(div1, expr); + }; + + if (const OpNode* expr_op_node = call_node->op.as()) { + if (expr_op_node->name == "divide") { + if (auto* input_call_node = call_node->args[0].as()) { + if (const OpNode* input_op_node = input_call_node->op.as()) { + if (input_op_node->name == "multiply") { + return associate_mul_div(); + } + } + } + if (auto* input_call_node = call_node->args[1].as()) { + if (const OpNode* input_op_node = input_call_node->op.as()) { + if (input_op_node->name == "multiply") { + return associate_mul_div(); + } + } + } + } + } + } + } + } + } + } } } - return matches; + return false; } bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { return op->expr == expr; @@ -267,33 +413,8 @@ DFPattern DFPatternMutator::VisitDFPattern_(const WildcardPatternNode* op) { return GetRef(op); } -// Prepare - -class DFPatternPrepare : protected DFPatternMutator { - public: - DFPattern Prepare(const DFPattern& pattern) { return Mutate(pattern); } - DFPattern VisitDFPattern_(const CallPatternNode* op) { - auto post = DFPatternMutator::VisitDFPattern_(op); - auto* post_node = post.as(); - if (auto* expr_pattern = post_node->op.as()) { - if (auto* op_node = expr_pattern->expr.as()) { - if ((op_node->name == "add") || (op_node->name == "multiply")) { - tvm::Array call_args; - for (auto it = post_node->args.rbegin(); it != post_node->args.rend(); ++it) { - call_args.push_back(*it); - } - return AltPatternNode::make( - post, CallPatternNode::make(post_node->op, call_args, post_node->attrs, - post_node->type_args)); - } - } - } - return post; - } -}; - TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) { - return DFPatternMatcher().Match(DFPatternPrepare().Prepare(pattern), expr); + return DFPatternMatcher().Match(pattern, expr); }); // Rewrite @@ -310,7 +431,7 @@ TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") .set_body_typed(DFPatternCallbackNode::make); -class PatternRewriter : public ExprMutator { +class PatternRewriter : protected MixedModeMutator { public: PatternRewriter(const Array& callbacks) : callbacks_(callbacks) {} Expr Rewrite(const Expr& pre) { @@ -318,8 +439,8 @@ class PatternRewriter : public ExprMutator { } protected: - Expr VisitExpr(const Expr& pre) override { - auto post = ExprMutator::VisitExpr(pre); + Expr DispatchVisitExpr(const Expr& pre) override { + auto post = MixedModeMutator::DispatchVisitExpr(pre); Expr out = post; for (auto& callback : callbacks_) { if (auto* callback_node = callback.as()) { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 18fd1c711dd0..684dae7cc481 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -142,7 +142,8 @@ void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {} void MixedModeMutator::VisitLeaf(const Expr& expr) { if (!memo_.count(expr)) { - this->DispatchVisitExpr(expr); + Expr ret = this->DispatchVisitExpr(expr); + memo_[expr] = ret; } } @@ -163,9 +164,7 @@ Expr MixedModeMutator::VisitExpr(const Expr& expr) { return memo_[expr]; } else { ExpandDataflow(expr, fcheck_visited, fvisit_leaf); - Expr ret = this->DispatchVisitExpr(expr); - memo_[expr] = ret; - return ret; + return memo_[expr]; } } diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 3989fb570569..306950ca75eb 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -17,6 +17,7 @@ import tvm from tvm import relay from tvm.relay.df_pattern import * +import numpy as np # NB: 1 corresponds to the C++ enum that specicfies this # we loose the type safety due to the Python/C++ calling @@ -225,7 +226,7 @@ def test_match_fake_diamond(): # Check assert not diamond.match(out) -def test_match_rewrite(): +def test_rewrite(): x = relay.var('x') y = relay.var('y') add_pattern = is_op('add')(wildcard(), wildcard()) @@ -235,22 +236,141 @@ def add_to_sub(pre, post): out = rewrite([DFPatternCallback(add_pattern, add_to_sub)], x + y) assert sub_pattern.match(out) +def fuse_batchnorm(pre, post): + def left_right_call(post): + if isinstance(post.args[0], relay.Call): + return (post.args[1], post.args[0]) + else: + return (post.args[0], post.args[1]) + + beta, post = left_right_call(post) + assert isinstance(post, relay.Call) + + if post.op == relay.op.get("divide"): + numerator = post.args[0] + denominator = post.args[1] + gamma, numerator = left_right_call(numerator) + elif post.op == relay.op.get("multiply"): + gamma, quotient = left_right_call(post) + numerator = quotient.args[0] + denominator = quotient.args[1] + else: + raise "Found unexcepted op" + + x = numerator.args[0] + mean = numerator.args[1] + + var = denominator.args[0].args[0] + eps = denominator.args[0].args[1] + + out = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item()) + return out[0] + +def test_fuse_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() + BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + + out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) + assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) + +def test_no_fuse_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() + fake_BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta + + out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), fake_BN) + assert tvm.ir.structural_equal(out, fake_BN) + +def test_fuse_double_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() + BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + + out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN2) + + bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0] + bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon = 1e-5)[0] + + assert tvm.ir.structural_equal(out, bn2) + +def test_partial_fuse_double_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() + BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta + BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + + out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN2) + + bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon = 1e-5)[0] + + assert tvm.ir.structural_equal(out, bn2) + +def test_fuse_batchnorm_commutation(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() + #commute add + BN = beta + gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) + assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) + + # associate multiply/divide + BN = (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) * gamma + beta + out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) + assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) + + # associate divide/multiply + BN_pattern = wildcard() * ((wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard())) + wildcard() + BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) + assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) if __name__ == "__main__": - test_match_op() - test_no_match_op() - test_match_op_or() - test_match_call() - test_no_match_call() - test_match_call_commutive() - test_no_match_call_commutive() - test_match_tuple() - test_no_match_tuple() - test_match_type() - test_no_match_type() - test_match_attr() - test_no_match_attr() - test_match_diamond() - test_no_match_diamond() - test_match_fake_diamond() - test_match_rewrite() + #test_match_op() + #test_no_match_op() + #test_match_op_or() + #test_match_call() + #test_no_match_call() + #test_match_call_commutive() + #test_no_match_call_commutive() + #test_match_tuple() + #test_no_match_tuple() + #test_match_type() + #test_no_match_type() + #test_match_attr() + #test_no_match_attr() + #test_match_diamond() + #test_no_match_diamond() + #test_match_fake_diamond() + #test_rewrite() + #test_fuse_batchnorm() + #test_no_fuse_batchnorm() + #test_fuse_double_batchnorm() + #test_partial_fuse_double_batchnorm() + test_fuse_batchnorm_commutation() From def390e9b8802f136bc0ec53f5676f2e9e600433 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 6 Apr 2020 16:35:13 -0700 Subject: [PATCH 07/46] use watermark-based memoization resets to fix diamond matching --- src/relay/ir/dataflow_matcher.cc | 80 +++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 17a15e3f05ad..3e8311afef27 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -48,7 +48,9 @@ class DFPatternMatcher : public DFPatternFunctor memo_; + std::vector matched_nodes_; }; bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { @@ -56,16 +58,26 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { return VisitDFPattern(pattern, expr); } +void DFPatternMatcher::ClearMap(size_t watermark) { + for (size_t i = watermark; i < matched_nodes_.size(); ++i) { + memo_.erase(matched_nodes_[i]); + } + matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end()); +} bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) { -// if (memo_.count(pattern)) { -// return expr.same_as(memo_[pattern]); -// } else { - auto out = DFPatternFunctor::VisitDFPattern(pattern, expr); -// if (out) { -// memo_[pattern] = expr; -// } - return out; -// } + if (memo_.count(pattern)) { + return expr.same_as(memo_[pattern]); + } else { + auto watermark = matched_nodes_.size(); + auto out = DFPatternFunctor::VisitDFPattern(pattern, expr); + if (out) { + memo_[pattern] = expr; + matched_nodes_.push_back(pattern); + } else { + ClearMap(watermark); + } + return out; + } } bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { @@ -115,19 +127,7 @@ Array reverse(const Array args) { } bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) { - auto match_args = [this](const Array pattern_args, const Array expr_args) { - bool matches = true; - size_t i = 0; - if (pattern_args.size() == expr_args.size()) { - while (matches && i < pattern_args.size()) { - matches &= VisitDFPattern(pattern_args[i], expr_args[i]); - ++i; - } - } else { - matches = false; - } - return matches; - }; + auto watermark = matched_nodes_.size(); auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { if (op) { @@ -141,6 +141,24 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex if (const auto* call_node = expr.as()) { auto matches_op = VisitDFPattern(op->op, call_node->op); if (matches_op) { + auto watermark2 = matched_nodes_.size(); + auto match_args = [this, &watermark2](const Array pattern_args, const Array expr_args) { + bool matches = true; + size_t i = 0; + if (pattern_args.size() == expr_args.size()) { + while (matches && i < pattern_args.size()) { + matches &= VisitDFPattern(pattern_args[i], expr_args[i]); + ++i; + } + } else { + matches = false; + } + if (!matches) { + ClearMap(watermark2); + } + return matches; + }; + if (match_args(op->args, call_node->args)) { return true; } @@ -156,12 +174,13 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } } } else { + ClearMap(watermark); if (const OpNode* op_node = get_op_node(op)) { if (op_node->name == "divide") { if (auto* arg_node = op->args[0].as()) { if (const OpNode* arg_op = get_op_node(arg_node)) { if (arg_op->name == "multiply") { - auto associate_div_mul = [this, &op, &arg_node, &expr]() { + auto associate_div_mul = [this, &op, &arg_node, &expr, &watermark]() { auto div1 = CallPatternNode::make(op->op, {arg_node->args[1], op->args[1]}, op->attrs, op->type_args); auto mul1 = CallPatternNode::make(arg_node->op, {arg_node->args[0], div1}, @@ -170,7 +189,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex op->attrs, op->type_args); auto mul2 = CallPatternNode::make(arg_node->op, {arg_node->args[1], div2}, arg_node->attrs, arg_node->type_args); - return VisitDFPattern(mul1, expr)|| VisitDFPattern(mul2, expr); + auto out = VisitDFPattern(mul1, expr); + if (!out) { + ClearMap(watermark); + out = VisitDFPattern(mul2, expr); + } + return out; }; if (const OpNode* expr_op_node = call_node->op.as()) { @@ -455,10 +479,12 @@ class PatternRewriter : protected MixedModeMutator { Array callbacks_; }; -TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite") -.set_body_typed([](Array callbacks, Expr expr) { +Expr RewritePatterns(Array callbacks, Expr expr) { return PatternRewriter(callbacks).Rewrite(expr); -}); +} + +TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite") +.set_body_typed(RewritePatterns); } // namespace relay } // namespace tvm From ca6885f0d34ecb717a35dc565515cef168103d8f Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 8 Apr 2020 10:01:28 -0700 Subject: [PATCH 08/46] rough dominator matcher --- include/tvm/relay/dataflow_matcher.h | 3 ++ include/tvm/relay/dataflow_pattern.h | 52 +++++++++++++++++++ python/tvm/relay/df_pattern/__init__.py | 20 ++++++++ src/relay/ir/dataflow_matcher.cc | 67 +++++++++++++++++++++++-- src/relay/ir/dataflow_pattern.cc | 19 +++++++ tests/python/relay/test_df_pattern.py | 43 +++++++++++++++- 6 files changed, 197 insertions(+), 7 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 4eea4befb1b5..d6943929a5b1 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -85,6 +85,7 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; @@ -105,6 +106,7 @@ class DFPatternFunctor { RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); @@ -122,6 +124,7 @@ class DFPatternMutator : public DFPatternFunctor { DFPattern VisitDFPattern_(const AltPatternNode* op) override; DFPattern VisitDFPattern_(const AttrPatternNode* op) override; DFPattern VisitDFPattern_(const CallPatternNode* op) override; + DFPattern VisitDFPattern_(const DominatorPatternNode* op) override; DFPattern VisitDFPattern_(const ExprPatternNode* op) override; DFPattern VisitDFPattern_(const TupleGetItemPatternNode* op) override; DFPattern VisitDFPattern_(const TuplePatternNode* op) override; diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index beac285e3a4d..5845778924e6 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -276,6 +276,25 @@ class WildcardPattern : public DFPattern { TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); }; +/*! + * \brief Null Pattern. + */ +class NullPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.df_pattern.NullPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(NullPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches anything. + */ +class NullPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(NullPattern, DFPattern, NullPatternNode); +}; + class TypePattern; /*! * \brief Pattern for Types. @@ -336,6 +355,39 @@ class AttrPattern : public DFPattern { TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); }; +class DominatorPattern; +/*! + * \brief Pattern for Types. + */ +class DominatorPatternNode : public DFPatternNode { + public: + /*! \brief The parent. */ + DFPattern parent; + /*! \brief The path. */ + DFPattern path; + /*! \brief The child. */ + DFPattern child; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("parent", &parent); + v->Visit("path", &path); + v->Visit("child", &child); + } + + TVM_DLL static DominatorPattern make(DFPattern parent, DFPattern path, DFPattern child); + + static constexpr const char* _type_key = "relay.df_pattern.DominatorPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DominatorPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a variable length dominator path + */ +class DominatorPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode); +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_DATAFLOW_PATTERN_H_ diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 40480ed59311..30a5450ae947 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -224,6 +224,23 @@ def __init__(self, pattern, attrs): self.__init_handle_by_constructor__( ffi.AttrPattern, pattern, attrs) +@register_df_node +class DominatorPattern(DFPattern): + """Get index-th item from a TuplePattern. + + Parameters + ---------- + parent: tvm.relay.df_pattern.DFPattern + The root of domination + path: tvm.relay.df_pattern.DFPattern + The fuzzy path pattern between parent and child + child: tvm.relay.df_pattern.DFPattern + The last node in the domination + """ + def __init__(self, parent, path, child): + self.__init_handle_by_constructor__( + ffi.DominatorPattern, parent, path, child) + class DFPatternCallback(Object): def __init__(self, pattern, callback): self.__init_handle_by_constructor__( @@ -256,3 +273,6 @@ def rewrite(callbacks, expr): if isinstance(callbacks, DFPatternCallback): callbacks = [callbacks] return ffi.rewrite(callbacks, expr) + +def dominates(parent, path, child): + return DominatorPattern(parent, path, child) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 3e8311afef27..c7c98a2d40e6 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -41,6 +41,7 @@ class DFPatternMatcher : public DFPatternFunctor matched_nodes_; }; +class DominatorMatcher : public DFPatternMatcher { + public: + DominatorMatcher(const DominatorPatternNode* dominator) : dominator_(dominator) {} + bool Dominates(const Expr& expr) { + found_child = DFPatternMatcher::VisitDFPattern(dominator_->child, expr); + if (found_child) { + return false; + } + return false; + } + + const std::unordered_map& GetMemo() { return memo_; } + const std::vector GetMatched() { return matched_nodes_; } + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override { + std::cout << "visiting " << pattern << "\n\t -> " << expr << std::endl; + if (DFPatternMatcher::VisitDFPattern(pattern, expr)) { + return true; + } else if (found_child) { + if (DFPatternMatcher::VisitDFPattern(dominator_->parent, expr)) { + return true; + } else { + return DFPatternMatcher::VisitDFPattern(dominator_->path, expr); + } + } + return false; + } + const DominatorPatternNode* dominator_; + bool found_child = false; +}; + bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { memo_.clear(); + matched_nodes_.clear(); return VisitDFPattern(pattern, expr); } @@ -163,11 +196,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex return true; } if (auto* op_node = get_op_node(op)) { - if ((op_node->name == "add")) { - if (match_args(reverse(op->args), call_node->args)) { - return true; - } - } else if ((op_node->name == "multiply")) { + if ((op_node->name == "add") || (op_node->name == "multiply")) { if (match_args(reverse(op->args), call_node->args)) { return true; } @@ -175,6 +204,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } } else { ClearMap(watermark); + // TODO(mbrookhart): This is nasty. Find a cleaner way to do this if (const OpNode* op_node = get_op_node(op)) { if (op_node->name == "divide") { if (auto* arg_node = op->args[0].as()) { @@ -289,6 +319,19 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } return false; } +bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { + DominatorMatcher visitor(op); + if (visitor.Dominates(expr)) { + const auto new_memo = visitor.GetMemo(); + const auto new_matched = visitor.GetMatched(); + for (const auto &pattern : new_matched) { + matched_nodes_.push_back(pattern); + memo_[pattern] = new_memo.at(pattern); + } + return true; + } + return false; +} bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { return op->expr == expr; } @@ -341,6 +384,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } + + // DFPatternMutator DFPattern DFPatternMutator::Mutate(const DFPattern& pattern) { return VisitDFPattern(pattern); } @@ -392,6 +437,18 @@ DFPattern DFPatternMutator::VisitDFPattern_(const CallPatternNode* op) { } } +DFPattern DFPatternMutator::VisitDFPattern_(const DominatorPatternNode* op) { + auto new_parent = Mutate(op->parent); + auto new_path = Mutate(op->path); + auto new_child = Mutate(op->child); + if (op->parent.same_as(new_child) && op->parent.same_as(new_child) && + op->parent.same_as(new_child)) { + return GetRef(op); + } else { + return DominatorPatternNode::make(new_parent, new_path, new_child); + } +} + DFPattern DFPatternMutator::VisitDFPattern_(const ExprPatternNode* op) { return GetRef(op); } diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 5a34ae9d832d..8568205541dc 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -193,5 +193,24 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; }); +DominatorPattern DominatorPatternNode::make(DFPattern parent, DFPattern path, DFPattern child) { + ObjectPtr n = make_object(); + n->parent = std::move(parent); + n->path = std::move(path); + n->child = std::move(child); + return DominatorPattern(n); +} + +TVM_REGISTER_NODE_TYPE(DominatorPatternNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.DominatorPattern").set_body_typed(DominatorPatternNode::make); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "DominatorPattern(" << node->parent << ", " << node->path << ", " << node->child + << ")"; + }); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 306950ca75eb..7d8fd0f97213 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -226,6 +226,26 @@ def test_match_fake_diamond(): # Check assert not diamond.match(out) + +def test_match_dominator(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_elemwise = wildcard().has_attr("TOpPattern", K_ELEMWISE) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_elemwise, reduction) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + def test_rewrite(): x = relay.var('x') y = relay.var('y') @@ -236,7 +256,25 @@ def add_to_sub(pre, post): out = rewrite([DFPatternCallback(add_pattern, add_to_sub)], x + y) assert sub_pattern.match(out) -def fuse_batchnorm(pre, post): +def test_not_fuse_multi_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + out = out + conv2d + # Check + assert not diamond.match(out) + +def (pre, post): def left_right_call(post): if isinstance(post.args[0], relay.Call): return (post.args[1], post.args[0]) @@ -373,4 +411,5 @@ def test_fuse_batchnorm_commutation(): #test_no_fuse_batchnorm() #test_fuse_double_batchnorm() #test_partial_fuse_double_batchnorm() - test_fuse_batchnorm_commutation() + #test_fuse_batchnorm_commutation() + test_match_dominator() From 523131808d2edf1d91e97f78a5f6e9d24bf32f63 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 8 Apr 2020 12:01:51 -0700 Subject: [PATCH 09/46] move dataflow functor/visitor/mutator to separate file --- include/tvm/relay/dataflow_functor.h | 159 +++++++++ src/relay/ir/dataflow_functor.cc | 470 ++++++++++++++++++++++++++ src/relay/ir/dataflow_matcher.cc | 110 ------ tests/python/relay/test_df_pattern.py | 2 +- 4 files changed, 630 insertions(+), 111 deletions(-) create mode 100644 include/tvm/relay/dataflow_functor.h create mode 100644 src/relay/ir/dataflow_functor.cc diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h new file mode 100644 index 000000000000..73cccd15c3bb --- /dev/null +++ b/include/tvm/relay/dataflow_functor.h @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/dataflow_matcher.h + * \brief A pattern matcher for matching dataflow properties. + */ +#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_ +#define TVM_RELAY_DATAFLOW_FUNCTOR_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor that dispatches on in the first DFPattern argument. + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const DFPattern&, + * Args...) + */ +template +class DFPatternFunctor; + +// functions to be overriden. +#define DFPATTERN_FUNCTOR_DEFAULT \ + { return VisitDFPatternDefault_(op, std::forward(args)...); } + +#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitDFPattern_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class DFPatternFunctor { + private: + using TSelf = DFPatternFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~DFPatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const DFPattern& n, Args... args) { + return VisitDFPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitDFPattern(const DFPattern& n, Args... args) { + CHECK(n.defined()); + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPatternDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); + return vtable; + } +}; + +class DFPatternVisitor : public DFPatternFunctor { + public: + void VisitDFPattern(const DFPattern& pattern) override; + void VisitDFPattern_(const AltPatternNode* op) override; + void VisitDFPattern_(const AttrPatternNode* op) override; + void VisitDFPattern_(const CallPatternNode* op) override; + void VisitDFPattern_(const DominatorPatternNode* op) override; + void VisitDFPattern_(const ExprPatternNode* op) override; + void VisitDFPattern_(const TupleGetItemPatternNode* op) override; + void VisitDFPattern_(const TuplePatternNode* op) override; + void VisitDFPattern_(const TypePatternNode* op) override; + void VisitDFPattern_(const VarPatternNode* op) override; + void VisitDFPattern_(const WildcardPatternNode* op) override; + + protected: + std::unordered_set visited_; +}; + +class DFPatternMutator : public DFPatternFunctor { + public: + virtual DFPattern Mutate(const DFPattern& pattern); + DFPattern VisitDFPattern(const DFPattern& pattern) override; + DFPattern VisitDFPattern_(const AltPatternNode* op) override; + DFPattern VisitDFPattern_(const AttrPatternNode* op) override; + DFPattern VisitDFPattern_(const CallPatternNode* op) override; + DFPattern VisitDFPattern_(const DominatorPatternNode* op) override; + DFPattern VisitDFPattern_(const ExprPatternNode* op) override; + DFPattern VisitDFPattern_(const TupleGetItemPatternNode* op) override; + DFPattern VisitDFPattern_(const TuplePatternNode* op) override; + DFPattern VisitDFPattern_(const TypePatternNode* op) override; + DFPattern VisitDFPattern_(const VarPatternNode* op) override; + DFPattern VisitDFPattern_(const WildcardPatternNode* op) override; + + protected: + std::unordered_map memo_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_DATAFLOW_FUNCTOR_H_ diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_functor.cc new file mode 100644 index 000000000000..1d29faafa18e --- /dev/null +++ b/src/relay/ir/dataflow_functor.cc @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relay. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +// DFPatternVisitor + +void DFPatternVisitor::VisitDFPattern(const DFPattern& pattern) { + if (this->visited_.count(pattern.get()) == 0) { + visited_.insert(pattern.get()); + DFPatternFunctor::VisitDFPattern(pattern); + } +} + +void DFPatternVisitor::VisitDFPattern_(const AltPatternNode* op) { + VisitDFPattern(op->left); + VisitDFPattern(op->right); +} + +void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) { + VisitDFPattern(op->op); + for (auto arg : op->args) { + VisitDFPattern(arg); + } +} +void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) { + VisitDFPattern(op->parent); + VisitDFPattern(op->path); + VisitDFPattern(op->child); +} + +void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {} + +void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) { + VisitDFPattern(op->tuple); +} + +void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) { + for (auto field : op->fields) { + VisitDFPattern(field); + } +} + +void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} + +void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} + +// DFPatternMutator + +DFPattern DFPatternMutator::Mutate(const DFPattern& pattern) { return VisitDFPattern(pattern); } + +DFPattern DFPatternMutator::VisitDFPattern(const DFPattern& pattern) { + auto it = this->memo_.find(pattern); + if (it != this->memo_.end()) { + return it->second; + } else { + auto new_pattern = DFPatternFunctor::VisitDFPattern(pattern); + memo_[pattern] = new_pattern; + return new_pattern; + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const AltPatternNode* op) { + auto new_left = Mutate(op->left); + auto new_right = Mutate(op->right); + + if (new_left.same_as(op->left) && new_right.same_as(op->right)) { + return GetRef(op); + } else { + return AltPatternNode::make(new_left, new_right); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const AttrPatternNode* op) { + auto new_pattern = Mutate(op->pattern); + if (new_pattern.same_as(op->pattern)) { + return GetRef(op); + } else { + return AttrPatternNode::make(new_pattern, op->attrs); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const CallPatternNode* op) { + auto new_op = Mutate(op->op); + bool unchanged = op->op.same_as(new_op); + tvm::Array call_args; + for (auto arg : op->args) { + auto new_arg = Mutate(arg); + call_args.push_back(new_arg); + unchanged &= arg.same_as(new_arg); + } + if (unchanged) { + return GetRef(op); + } else { + return CallPatternNode::make(new_op, call_args, op->attrs, op->type_args); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const DominatorPatternNode* op) { + auto new_parent = Mutate(op->parent); + auto new_path = Mutate(op->path); + auto new_child = Mutate(op->child); + if (op->parent.same_as(new_child) && op->parent.same_as(new_child) && + op->parent.same_as(new_child)) { + return GetRef(op); + } else { + return DominatorPatternNode::make(new_parent, new_path, new_child); + } +} + + +DFPattern DFPatternMutator::VisitDFPattern_(const ExprPatternNode* op) { + return GetRef(op); +} + +DFPattern DFPatternMutator::VisitDFPattern_(const TupleGetItemPatternNode* op) { + auto new_tuple = Mutate(op->tuple); + if (new_tuple.same_as(op->tuple)) { + return GetRef(op); + } else { + return TupleGetItemPatternNode::make(op->tuple, op->index); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const TuplePatternNode* op) { + bool unchanged = true; + tvm::Array fields; + for (auto field : op->fields) { + auto new_field = Mutate(field); + fields.push_back(new_field); + unchanged &= field.same_as(new_field); + } + if (unchanged) { + return GetRef(op); + } else { + return TuplePatternNode::make(fields); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const TypePatternNode* op) { + auto new_pattern = Mutate(op->pattern); + if (new_pattern.same_as(op->pattern)) { + return GetRef(op); + } else { + return TypePatternNode::make(new_pattern, op->type); + } +} + +DFPattern DFPatternMutator::VisitDFPattern_(const VarPatternNode* op) { + return GetRef(op); +} + +DFPattern DFPatternMutator::VisitDFPattern_(const WildcardPatternNode* op) { + return GetRef(op); +} + +//IndexedGraph + +template +struct Node { + Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + const T ref_; + const size_t index_; + std::vector>> outputs_; +}; + +template +struct IndexedGraph { + std::unordered_map>, ObjectHash, ObjectEqual> node_map_; + std::vector>> topological_order_; +}; + +IndexedGraph +CreateIndexedGraph(const Expr& expr) { + using NodePtr = std::shared_ptr>; + class Creator : public MixedModeVisitor { + public: + IndexedGraph CreateGraph(const Expr& expr) { + VisitExpr(expr); + return std::move(graph_); + } + void Create(const Expr& expr) { VisitExpr(expr); } + + protected: + void VisitLeaf(const Expr& expr) override { + MixedModeVisitor::VisitLeaf(expr); + auto node = std::make_shared>(expr, index_++); + graph_.node_map_[expr] = node; + graph_.topological_order_.push_back(node); + } + IndexedGraph graph_; + size_t index_ = 0; + }; + class Annotator : public ExprFunctor { + public: + Annotator(const IndexedGraph& graph) : graph_(graph) {} + IndexedGraph Annotate() { + for (const auto& node : graph_.topological_order_) { + ExprFunctor::VisitExpr(node->ref_, nullptr); + } + return std::move(graph_); + } + + void VisitExpr(const Expr& expr, NodePtr parent) override { + if (parent) { + graph_.node_map_[expr]->outputs_.push_back(parent); + } + } + + protected: + IndexedGraph graph_; + void VisitExpr_(const VarNode* op, NodePtr parent) override { + if (op->type_annotation.defined()) { + this->VisitType(op->type_annotation); + } + } + + void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} + + void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} + + void VisitExpr_(const TupleNode* op, NodePtr parent) override { + for (auto field : op->fields) { + this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + } + } + + void VisitExpr_(const FunctionNode* op, NodePtr parent) override { + for (auto param : op->params) { + this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + } + + this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const CallNode* op, NodePtr parent) override { + this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + } + } + + void VisitExpr_(const LetNode* op, NodePtr parent) override { + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const IfNode* op, NodePtr parent) override { + this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } + + void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { + this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefReadNode* op, NodePtr parent) override { + this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { + this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { + for (const Type& t : op->inputs) { + this->VisitType(t); + } + this->VisitType(op->belong_to); + } + + void VisitExpr_(const MatchNode* op, NodePtr parent) override { + this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); + for (const Clause& c : op->clauses) { + this->VisitClause(c, graph_.node_map_[GetRef(op)]); + } + } + + void VisitClause(const Clause& op, NodePtr parent) { + this->VisitPattern(op->lhs); + this->VisitExpr(op->rhs, parent); + } + + void VisitPattern(const Pattern& p) { return; } + + void VisitType(const Type& t) { return; } + }; + return Annotator(Creator().CreateGraph(expr)).Annotate(); +} +// +// IndexedGraph +// CreateIndexedGraph(const Expr& expr) { +// using NodePtr = std::shared_ptr; +// class Creator : public MixedModeVisitor { +// public: +// IndexedGraph CreateGraph(const Expr& expr) { +// VisitExpr(expr); +// return std::move(graph_); +// } +// void Create(const Expr& expr) { VisitExpr(expr); } +// +// protected: +// void DispatchVisitExpr(const Expr& expr) override { +// MixedModeVisitor::DispatchVisitExpr(expr); +// graph_.node_map[expr] = std::make_shared(expr, index++, {}); +// l graph_.topological_order.push_back(expr); +// } +// IndexdecGraph graph_; +// size_t index_ = 0; +// }; +// class Annotator : public ExprFunctor { +// public: +// Annotator(const IndexedGraph& graph) : graph_(graph) {} +// Annotate() { +// for (const auto& node : graph_.topological_order) { +// ExprFunctor::VisitExpr(node.ref, nullptr); +// } +// return std::move(graph_); +// } +// +// VisitExpr(const Expr& expr, NodePtr parent) override { +// if (parent) { +// graph_.node_map[expr].outputs.push_back(parent); +// } +// } +// +// protected: +// void VisitExpr_(const VarNode* op, NodePtr parent) override { +// if (op->type_annotation.defined()) { +// this->VisitType(op->type_annotation); +// } +// } +// +// void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} +// +// void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} +// +// void VisitExpr_(const TupleNode* op, NodePtr parent) override { +// for (auto field : op->fields) { +// this->VisitExpr(field, GetRef(op)); +// } +// } +// +// void VisitExpr_(const FunctionNode* op, NodePtr parent) override { +// for (auto param : op->params) { +// this->VisitExpr(param, GetRef(op)); +// } +// +// this->VisitExpr(op->body, GetRef(op)); +// } +// +// void VisitExpr_(const CallNode* op, NodePtr parent) override { +// this->VisitExpr(op->op, GetRef(op)); +// +// for (auto ty_arg : op->type_args, NodePtr parent) override { +// this->VisitType(ty_arg, GetRef(op)); +// } +// +// for (auto arg : op->args) { +// this->VisitExpr(arg, GetRef(op)); +// } +// } +// +// void VisitExpr_(const LetNode* op, NodePtr parent) override { +// this->VisitExpr(op->value, GetRef(op)); +// this->VisitExpr(op->var, GetRef(op)); +// this->VisitExpr(op->body, GetRef(op)); +// } +// +// void VisitExpr_(const IfNode* op, NodePtr parent) override { +// this->VisitExpr(op->cond, GetRef(op)); +// this->VisitExpr(op->true_branch, GetRef(op)); +// this->VisitExpr(op->false_branch, GetRef(op)); +// } +// +// void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } +// +// void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { +// this->VisitExpr(op->tuple, GetRef(op)); +// } +// +// void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { +// this->VisitExpr(op->value, GetRef(op)); +// } +// +// void VisitExpr_(const RefReadNode* op, NodePtr parent) override { +// this->VisitExpr(op->ref, GetRef(op)); +// } +// +// void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { +// this->VisitExpr(op->ref, GetRef(op)); +// this->VisitExpr(op->value, GetRef(op)); +// } +// +// void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { +// for (const Type& t : op->inputs) { +// this->VisitType(t); +// } +// this->VisitType(op->belong_to); +// } +// +// void VisitExpr_(const MatchNode* op, NodePtr parent) override { +// this->VisitExpr(op->data, GetRef(op)); +// for (const Clause& c : op->clauses) { +// this->VisitClause(c, GetRef(op)); +// } +// } +// +// void VisitClause(const Clause& op, NodePtr parent) override { +// this->VisitPattern(op->lhs); +// this->VisitExpr(op->rhs, parent); +// } +// +// void VisitPattern(const Pattern& p) { return; } +// +// void VisitType(const Type& t) { return; } +// }; +// return Annotator(Creator().CreateGraph(expr)).Annotate(); +// } + +//IndexedGraph CreateIndexedGraph(const DFPattern&) { +//} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index c7c98a2d40e6..e7b06711fa6c 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -384,116 +384,6 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } - - -// DFPatternMutator - -DFPattern DFPatternMutator::Mutate(const DFPattern& pattern) { return VisitDFPattern(pattern); } - -DFPattern DFPatternMutator::VisitDFPattern(const DFPattern& pattern) { - auto it = this->memo_.find(pattern); - if (it != this->memo_.end()) { - return it->second; - } else { - auto new_pattern = DFPatternFunctor::VisitDFPattern(pattern); - memo_[pattern] = new_pattern; - return new_pattern; - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const AltPatternNode* op) { - auto new_left = Mutate(op->left); - auto new_right = Mutate(op->right); - - if (new_left.same_as(op->left) && new_right.same_as(op->right)) { - return GetRef(op); - } else { - return AltPatternNode::make(new_left, new_right); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const AttrPatternNode* op) { - auto new_pattern = Mutate(op->pattern); - if (new_pattern.same_as(op->pattern)) { - return GetRef(op); - } else { - return AttrPatternNode::make(new_pattern, op->attrs); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const CallPatternNode* op) { - auto new_op = Mutate(op->op); - bool unchanged = op->op.same_as(new_op); - tvm::Array call_args; - for (auto arg : op->args) { - auto new_arg = Mutate(arg); - call_args.push_back(new_arg); - unchanged &= arg.same_as(new_arg); - } - if (unchanged) { - return GetRef(op); - } else { - return CallPatternNode::make(new_op, call_args, op->attrs, op->type_args); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const DominatorPatternNode* op) { - auto new_parent = Mutate(op->parent); - auto new_path = Mutate(op->path); - auto new_child = Mutate(op->child); - if (op->parent.same_as(new_child) && op->parent.same_as(new_child) && - op->parent.same_as(new_child)) { - return GetRef(op); - } else { - return DominatorPatternNode::make(new_parent, new_path, new_child); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const ExprPatternNode* op) { - return GetRef(op); -} - -DFPattern DFPatternMutator::VisitDFPattern_(const TupleGetItemPatternNode* op) { - auto new_tuple = Mutate(op->tuple); - if (new_tuple.same_as(op->tuple)) { - return GetRef(op); - } else { - return TupleGetItemPatternNode::make(op->tuple, op->index); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const TuplePatternNode* op) { - bool unchanged = true; - tvm::Array fields; - for (auto field : op->fields) { - auto new_field = Mutate(field); - fields.push_back(new_field); - unchanged &= field.same_as(new_field); - } - if (unchanged) { - return GetRef(op); - } else { - return TuplePatternNode::make(fields); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const TypePatternNode* op) { - auto new_pattern = Mutate(op->pattern); - if (new_pattern.same_as(op->pattern)) { - return GetRef(op); - } else { - return TypePatternNode::make(new_pattern, op->type); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const VarPatternNode* op) { - return GetRef(op); -} - -DFPattern DFPatternMutator::VisitDFPattern_(const WildcardPatternNode* op) { - return GetRef(op); -} - TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) { return DFPatternMatcher().Match(pattern, expr); }); diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 7d8fd0f97213..e97c2b4c28d2 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -274,7 +274,7 @@ def test_not_fuse_multi_diamond(): # Check assert not diamond.match(out) -def (pre, post): +def fuse_batchnorm(pre, post): def left_right_call(post): if isinstance(post.args[0], relay.Call): return (post.args[1], post.args[0]) From 09c8e51aabcbc178db1b0401f1938e513aaf38fb Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 8 Apr 2020 13:16:36 -0700 Subject: [PATCH 10/46] code complete forward graph creation --- include/tvm/relay/dataflow_functor.h | 17 ++ src/relay/ir/dataflow_functor.cc | 239 ++++++++++----------------- 2 files changed, 104 insertions(+), 152 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index 73cccd15c3bb..73892846a71e 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -153,6 +153,23 @@ class DFPatternMutator : public DFPatternFunctor { std::unordered_map memo_; }; + +template +struct Node { + Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + const T ref_; + const size_t index_; + std::vector>> outputs_; +}; + +template +struct IndexedGraph { + std::unordered_map>, ObjectHash, ObjectEqual> node_map_; + std::vector>> topological_order_; +}; + +IndexedGraph CreateIndexedGraph(const Expr& expr); +IndexedGraph CreateIndexedGraph(const DFPattern& pattern); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_functor.cc index 1d29faafa18e..a7837f1611a8 100644 --- a/src/relay/ir/dataflow_functor.cc +++ b/src/relay/ir/dataflow_functor.cc @@ -185,24 +185,9 @@ DFPattern DFPatternMutator::VisitDFPattern_(const WildcardPatternNode* op) { return GetRef(op); } -//IndexedGraph - -template -struct Node { - Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} - const T ref_; - const size_t index_; - std::vector>> outputs_; -}; - -template -struct IndexedGraph { - std::unordered_map>, ObjectHash, ObjectEqual> node_map_; - std::vector>> topological_order_; -}; - -IndexedGraph -CreateIndexedGraph(const Expr& expr) { +// IndexedGraph + +IndexedGraph CreateIndexedGraph(const Expr& expr) { using NodePtr = std::shared_ptr>; class Creator : public MixedModeVisitor { public: @@ -210,7 +195,6 @@ CreateIndexedGraph(const Expr& expr) { VisitExpr(expr); return std::move(graph_); } - void Create(const Expr& expr) { VisitExpr(expr); } protected: void VisitLeaf(const Expr& expr) override { @@ -332,139 +316,90 @@ CreateIndexedGraph(const Expr& expr) { }; return Annotator(Creator().CreateGraph(expr)).Annotate(); } -// -// IndexedGraph -// CreateIndexedGraph(const Expr& expr) { -// using NodePtr = std::shared_ptr; -// class Creator : public MixedModeVisitor { -// public: -// IndexedGraph CreateGraph(const Expr& expr) { -// VisitExpr(expr); -// return std::move(graph_); -// } -// void Create(const Expr& expr) { VisitExpr(expr); } -// -// protected: -// void DispatchVisitExpr(const Expr& expr) override { -// MixedModeVisitor::DispatchVisitExpr(expr); -// graph_.node_map[expr] = std::make_shared(expr, index++, {}); -// l graph_.topological_order.push_back(expr); -// } -// IndexdecGraph graph_; -// size_t index_ = 0; -// }; -// class Annotator : public ExprFunctor { -// public: -// Annotator(const IndexedGraph& graph) : graph_(graph) {} -// Annotate() { -// for (const auto& node : graph_.topological_order) { -// ExprFunctor::VisitExpr(node.ref, nullptr); -// } -// return std::move(graph_); -// } -// -// VisitExpr(const Expr& expr, NodePtr parent) override { -// if (parent) { -// graph_.node_map[expr].outputs.push_back(parent); -// } -// } -// -// protected: -// void VisitExpr_(const VarNode* op, NodePtr parent) override { -// if (op->type_annotation.defined()) { -// this->VisitType(op->type_annotation); -// } -// } -// -// void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} -// -// void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} -// -// void VisitExpr_(const TupleNode* op, NodePtr parent) override { -// for (auto field : op->fields) { -// this->VisitExpr(field, GetRef(op)); -// } -// } -// -// void VisitExpr_(const FunctionNode* op, NodePtr parent) override { -// for (auto param : op->params) { -// this->VisitExpr(param, GetRef(op)); -// } -// -// this->VisitExpr(op->body, GetRef(op)); -// } -// -// void VisitExpr_(const CallNode* op, NodePtr parent) override { -// this->VisitExpr(op->op, GetRef(op)); -// -// for (auto ty_arg : op->type_args, NodePtr parent) override { -// this->VisitType(ty_arg, GetRef(op)); -// } -// -// for (auto arg : op->args) { -// this->VisitExpr(arg, GetRef(op)); -// } -// } -// -// void VisitExpr_(const LetNode* op, NodePtr parent) override { -// this->VisitExpr(op->value, GetRef(op)); -// this->VisitExpr(op->var, GetRef(op)); -// this->VisitExpr(op->body, GetRef(op)); -// } -// -// void VisitExpr_(const IfNode* op, NodePtr parent) override { -// this->VisitExpr(op->cond, GetRef(op)); -// this->VisitExpr(op->true_branch, GetRef(op)); -// this->VisitExpr(op->false_branch, GetRef(op)); -// } -// -// void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } -// -// void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { -// this->VisitExpr(op->tuple, GetRef(op)); -// } -// -// void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { -// this->VisitExpr(op->value, GetRef(op)); -// } -// -// void VisitExpr_(const RefReadNode* op, NodePtr parent) override { -// this->VisitExpr(op->ref, GetRef(op)); -// } -// -// void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { -// this->VisitExpr(op->ref, GetRef(op)); -// this->VisitExpr(op->value, GetRef(op)); -// } -// -// void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { -// for (const Type& t : op->inputs) { -// this->VisitType(t); -// } -// this->VisitType(op->belong_to); -// } -// -// void VisitExpr_(const MatchNode* op, NodePtr parent) override { -// this->VisitExpr(op->data, GetRef(op)); -// for (const Clause& c : op->clauses) { -// this->VisitClause(c, GetRef(op)); -// } -// } -// -// void VisitClause(const Clause& op, NodePtr parent) override { -// this->VisitPattern(op->lhs); -// this->VisitExpr(op->rhs, parent); -// } -// -// void VisitPattern(const Pattern& p) { return; } -// -// void VisitType(const Type& t) { return; } -// }; -// return Annotator(Creator().CreateGraph(expr)).Annotate(); -// } - -//IndexedGraph CreateIndexedGraph(const DFPattern&) { -//} + +IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { + using NodePtr = std::shared_ptr>; + class Creator : public DFPatternVisitor { + public: + IndexedGraph CreateGraph(const DFPattern& pattern) { + VisitDFPattern(pattern); + return std::move(graph_); + } + + protected: + void VisitDFPattern(const DFPattern& pattern) override { + DFPatternVisitor::VisitDFPattern(pattern); + auto node = std::make_shared>(pattern, index_++); + graph_.node_map_[pattern] = node; + graph_.topological_order_.push_back(node); + } + IndexedGraph graph_; + size_t index_ = 0; + }; + class Annotator : public DFPatternFunctor { + public: + Annotator(const IndexedGraph& graph) : graph_(graph) {} + IndexedGraph Annotate() { + for (const auto& node : graph_.topological_order_) { + DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); + } + return std::move(graph_); + } + + void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { + if (parent) { + graph_.node_map_[pattern]->outputs_.push_back(parent); + } + } + + protected: + IndexedGraph graph_; + void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); + for (auto arg : op->args) { + VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + } + } + void VisitDFPattern_(const DominatorPatternNode* op, + NodePtr parent) override { + VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + + void VisitDFPattern_(const TupleGetItemPatternNode* op, + NodePtr parent) override { + VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { + for (auto field : op->fields) { + VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + } + } + + void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} + + void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override { + } + }; + return Annotator(Creator().CreateGraph(pattern)).Annotate(); +} } // namespace relay } // namespace tvm From 4cb0fc956f4c1a1646a68ad3eb4761248e137827 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 8 Apr 2020 16:28:37 -0700 Subject: [PATCH 11/46] compiling dominator tree --- include/tvm/relay/dataflow_functor.h | 75 ++++++++++++++++++++++------ src/relay/ir/dataflow_functor.cc | 24 ++++++--- 2 files changed, 79 insertions(+), 20 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index 73892846a71e..106c074e91ec 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -153,24 +153,71 @@ class DFPatternMutator : public DFPatternFunctor { std::unordered_map memo_; }; - template -struct Node { - Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} - const T ref_; - const size_t index_; - std::vector>> outputs_; -}; +class IndexedGraph { + public: + struct Node; + struct Edge { + Edge(const std::shared_ptr& sink) : sink_(sink) {} + std::shared_ptr sink_; + }; + struct Node { + Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + const T ref_; + const size_t index_; + + bool is_external_ = false; + std::vector> outputs_; + + size_t depth_; + std::shared_ptr dominator_parent_; + }; + void PostDom() { + for (size_t i = topological_order_.size(); i != 0; --i) { + size_t index = i - 1; + auto current = topological_order_[index]; + if (current->is_external_) { + current->depth_ = 1; + current->dominator_parent_ = nullptr; + } else { + auto parent = LeastCommonAncestor(current->outputs_); + current->depth_ = parent ? parent->depth_ + 1 : 1; + current->dominator_parent_ = parent; + } + } + } + std::unordered_map, ObjectHash, ObjectEqual> node_map_; + std::vector> topological_order_; -template -struct IndexedGraph { - std::unordered_map>, ObjectHash, ObjectEqual> node_map_; - std::vector>> topological_order_; + protected: + std::shared_ptr LeastCommonAncestor(const std::vector>& outputs) { + if (outputs.size() == 0) { + return nullptr; + } + auto parent = outputs.at(0)->sink_; + for (size_t i = 1; i < outputs.size(); ++i) { + parent = LeastCommonAncestor(parent, outputs.at(i)->sink_); + } + return parent; + } + std::shared_ptr LeastCommonAncestor(std::shared_ptr lhs, std::shared_ptr rhs) { + if (lhs == nullptr || rhs == nullptr) { + return nullptr; + } + while (lhs != rhs) { + if (lhs->depth_ < rhs->depth_) { + rhs = rhs->dominator_parent_; + } else if (lhs->depth_ > rhs->depth_) { + lhs = lhs->dominator_parent_; + } else { + rhs = rhs->dominator_parent_; + lhs = lhs->dominator_parent_; + } + } + return lhs; + } }; -IndexedGraph CreateIndexedGraph(const Expr& expr); -IndexedGraph CreateIndexedGraph(const DFPattern& pattern); } // namespace relay } // namespace tvm - #endif // TVM_RELAY_DATAFLOW_FUNCTOR_H_ diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_functor.cc index a7837f1611a8..b428c12bdb0c 100644 --- a/src/relay/ir/dataflow_functor.cc +++ b/src/relay/ir/dataflow_functor.cc @@ -185,21 +185,24 @@ DFPattern DFPatternMutator::VisitDFPattern_(const WildcardPatternNode* op) { return GetRef(op); } + + // IndexedGraph IndexedGraph CreateIndexedGraph(const Expr& expr) { - using NodePtr = std::shared_ptr>; + using NodePtr = std::shared_ptr::Node>; class Creator : public MixedModeVisitor { public: IndexedGraph CreateGraph(const Expr& expr) { VisitExpr(expr); + graph_.node_map_[expr]->is_external_ = true; return std::move(graph_); } protected: void VisitLeaf(const Expr& expr) override { MixedModeVisitor::VisitLeaf(expr); - auto node = std::make_shared>(expr, index_++); + auto node = std::make_shared::Node>(expr, index_++); graph_.node_map_[expr] = node; graph_.topological_order_.push_back(node); } @@ -213,12 +216,15 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { for (const auto& node : graph_.topological_order_) { ExprFunctor::VisitExpr(node->ref_, nullptr); } + graph_.PostDom(); return std::move(graph_); } void VisitExpr(const Expr& expr, NodePtr parent) override { + auto current = graph_.node_map_[expr]; if (parent) { - graph_.node_map_[expr]->outputs_.push_back(parent); + auto edge = std::make_shared::Edge>(parent); + current->outputs_.push_back(edge); } } @@ -318,18 +324,19 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { } IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { - using NodePtr = std::shared_ptr>; + using NodePtr = std::shared_ptr::Node>; class Creator : public DFPatternVisitor { public: IndexedGraph CreateGraph(const DFPattern& pattern) { VisitDFPattern(pattern); + graph_.node_map_[pattern]->is_external_ = true; return std::move(graph_); } protected: void VisitDFPattern(const DFPattern& pattern) override { DFPatternVisitor::VisitDFPattern(pattern); - auto node = std::make_shared>(pattern, index_++); + auto node = std::make_shared::Node>(pattern, index_++); graph_.node_map_[pattern] = node; graph_.topological_order_.push_back(node); } @@ -343,12 +350,15 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { for (const auto& node : graph_.topological_order_) { DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); } + graph_.PostDom(); return std::move(graph_); } void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { + auto current = graph_.node_map_[pattern]; if (parent) { - graph_.node_map_[pattern]->outputs_.push_back(parent); + auto edge = std::make_shared::Edge>(parent); + current->outputs_.push_back(edge); } } @@ -401,5 +411,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { return Annotator(Creator().CreateGraph(pattern)).Annotate(); } + + } // namespace relay } // namespace tvm From cec1927bb59e2aaa5f6679f7d46c563fc4fc4c46 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 10 Apr 2020 08:54:42 -0700 Subject: [PATCH 12/46] partial dominator --- include/tvm/relay/dataflow_functor.h | 4 + include/tvm/relay/dataflow_matcher.h | 106 +-------------------------- src/relay/ir/dataflow_matcher.cc | 49 +++---------- 3 files changed, 15 insertions(+), 144 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index 106c074e91ec..e2457ad6ef43 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -106,6 +106,7 @@ class DFPatternFunctor { RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); @@ -218,6 +219,9 @@ class IndexedGraph { } }; +IndexedGraph CreateIndexedGraph(const Expr& expr); +IndexedGraph CreateIndexedGraph(const DFPattern& pattern); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_DATAFLOW_FUNCTOR_H_ diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index d6943929a5b1..b81e5996b8db 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAY_DATAFLOW_MATCHER_H_ #define TVM_RELAY_DATAFLOW_MATCHER_H_ +#include #include #include #include @@ -31,111 +32,6 @@ namespace tvm { namespace relay { -/*! - * \brief A dynamical functor that dispatches on in the first DFPattern argument. - * - * \tparam FType function signiture - * This type is only defined for FType with function signature R(const DFPattern&, - * Args...) - */ -template -class DFPatternFunctor; - -// functions to be overriden. -#define DFPATTERN_FUNCTOR_DEFAULT \ - { return VisitDFPatternDefault_(op, std::forward(args)...); } - -#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitDFPattern_(static_cast(n.get()), std::forward(args)...); \ - }); - -template -class DFPatternFunctor { - private: - using TSelf = DFPatternFunctor; - using FType = tvm::NodeFunctor; - - public: - /*! \brief the result type of this functor */ - using result_type = R; - /*! \brief virtual destructor */ - virtual ~DFPatternFunctor() {} - /*! - * \brief Same as call. - * \param n The expression node. - * \param args Additional arguments. - * \return The result of the call - */ - R operator()(const DFPattern& n, Args... args) { - return VisitDFPattern(n, std::forward(args)...); - } - /*! - * \brief The functor call. - * \param n The expression node. - * \param args Additional arguments. - * \return The result of the call - */ - virtual R VisitDFPattern(const DFPattern& n, Args... args) { - CHECK(n.defined()); - static FType vtable = InitVTable(); - return vtable(n, this, std::forward(args)...); - } - // Functions that can be overriden by subclass - virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, - Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPatternDefault_(const Object* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); - throw; - } - - private: - // initialize the vtable. - static FType InitVTable() { - FType vtable; - // Set dispatch - RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); - return vtable; - } -}; - -class DFPatternMutator : public DFPatternFunctor { - public: - virtual DFPattern Mutate(const DFPattern& pattern); - DFPattern VisitDFPattern(const DFPattern& pattern) override; - DFPattern VisitDFPattern_(const AltPatternNode* op) override; - DFPattern VisitDFPattern_(const AttrPatternNode* op) override; - DFPattern VisitDFPattern_(const CallPatternNode* op) override; - DFPattern VisitDFPattern_(const DominatorPatternNode* op) override; - DFPattern VisitDFPattern_(const ExprPatternNode* op) override; - DFPattern VisitDFPattern_(const TupleGetItemPatternNode* op) override; - DFPattern VisitDFPattern_(const TuplePatternNode* op) override; - DFPattern VisitDFPattern_(const TypePatternNode* op) override; - DFPattern VisitDFPattern_(const VarPatternNode* op) override; - DFPattern VisitDFPattern_(const WildcardPatternNode* op) override; - - protected: - std::unordered_map memo_; -}; - class DFPatternCallback; /*! * \brief Base type of all dataflow pattern callbacks. diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index e7b06711fa6c..9f880aba146d 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -54,37 +54,6 @@ class DFPatternMatcher : public DFPatternFunctor matched_nodes_; }; -class DominatorMatcher : public DFPatternMatcher { - public: - DominatorMatcher(const DominatorPatternNode* dominator) : dominator_(dominator) {} - bool Dominates(const Expr& expr) { - found_child = DFPatternMatcher::VisitDFPattern(dominator_->child, expr); - if (found_child) { - return false; - } - return false; - } - - const std::unordered_map& GetMemo() { return memo_; } - const std::vector GetMatched() { return matched_nodes_; } - protected: - bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override { - std::cout << "visiting " << pattern << "\n\t -> " << expr << std::endl; - if (DFPatternMatcher::VisitDFPattern(pattern, expr)) { - return true; - } else if (found_child) { - if (DFPatternMatcher::VisitDFPattern(dominator_->parent, expr)) { - return true; - } else { - return DFPatternMatcher::VisitDFPattern(dominator_->path, expr); - } - } - return false; - } - const DominatorPatternNode* dominator_; - bool found_child = false; -}; - bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { memo_.clear(); matched_nodes_.clear(); @@ -320,15 +289,17 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex return false; } bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { - DominatorMatcher visitor(op); - if (visitor.Dominates(expr)) { - const auto new_memo = visitor.GetMemo(); - const auto new_matched = visitor.GetMatched(); - for (const auto &pattern : new_matched) { - matched_nodes_.push_back(pattern); - memo_[pattern] = new_memo.at(pattern); + //auto watermark = matched_nodes_.size(); + if (VisitDFPattern(op->child, expr)) { + auto graph = CreateIndexedGraph(expr); + std::vector dominated; + for (auto node : graph.topological_order_) { + if (node->dominator_parent_ && node->dominator_parent_->ref_ == expr) { + dominated.push_back(node->ref_); + std::cout << node->ref_ << std::endl; + } } - return true; + return false; } return false; } From 455b726bf02a206330c126736d9ac4b12b49f86b Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 10 Apr 2020 10:21:05 -0700 Subject: [PATCH 13/46] functioning dominator matcher? --- src/relay/ir/dataflow_matcher.cc | 32 +++++++++++++++++++++------ tests/python/relay/test_df_pattern.py | 9 ++++---- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 9f880aba146d..5d30a64b6aad 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -289,17 +289,35 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex return false; } bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { - //auto watermark = matched_nodes_.size(); + auto watermark = matched_nodes_.size(); if (VisitDFPattern(op->child, expr)) { - auto graph = CreateIndexedGraph(expr); - std::vector dominated; - for (auto node : graph.topological_order_) { + bool matches = true; + std::unordered_set dominated_exprs; + auto child_graph = CreateIndexedGraph(op->child); + for (auto node : child_graph.topological_order_) { + if (node->ref_.as()) { + continue; + } + if (node->dominator_parent_ && node->dominator_parent_->ref_ == op->child) { + dominated_exprs.insert(memo_[node->ref_]); + } + } + ClearMap(watermark); + auto expr_graph = CreateIndexedGraph(expr); + for (auto node : expr_graph.topological_order_) { if (node->dominator_parent_ && node->dominator_parent_->ref_ == expr) { - dominated.push_back(node->ref_); - std::cout << node->ref_ << std::endl; + if (dominated_exprs.count(node->ref_) == 0) { + bool node_matches = VisitDFPattern(op->parent, node->ref_); + ClearMap(watermark); + matches = node_matches || VisitDFPattern(op->path, node->ref_); + ClearMap(watermark); + if (!matches) { + return false; + } + } } } - return false; + return matches; } return false; } diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index e97c2b4c28d2..6bba523b1a77 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -22,7 +22,8 @@ # NB: 1 corresponds to the C++ enum that specicfies this # we loose the type safety due to the Python/C++ calling # convention. -K_ELEMWISE = 1 +K_ELEMWISE = 0 +K_BROADCAST = 1 ## NODE TESTS def test_expr_pattern(): @@ -155,7 +156,7 @@ def test_no_match_type(): assert not ty_pat.match(x) def test_match_attr(): - op = is_op('add').has_attr("TOpPattern", K_ELEMWISE) + op = is_op('add').has_attr("TOpPattern", K_BROADCAST) op_pat = op(wildcard(), wildcard()) x = relay.var('x') y = relay.var('y') @@ -230,9 +231,9 @@ def test_match_fake_diamond(): def test_match_dominator(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_elemwise = wildcard().has_attr("TOpPattern", K_ELEMWISE) + is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) reduction = is_op('add')(wildcard(), wildcard()) - diamond = dominates(is_conv2d, is_elemwise, reduction) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) # Expr inp = relay.var('input') From e501daf71181b5169a3621c9dc97cc85d3d109f8 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 10 Apr 2020 15:20:39 -0700 Subject: [PATCH 14/46] clean up associative/commutative matching --- src/relay/ir/dataflow_matcher.cc | 162 ++++++++++---------------- tests/python/relay/test_df_pattern.py | 13 +-- 2 files changed, 66 insertions(+), 109 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 5d30a64b6aad..c6c166b5711b 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -129,8 +129,7 @@ Array reverse(const Array args) { } bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) { - auto watermark = matched_nodes_.size(); - + // utilities auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { if (op) { if (auto* expr_pattern = op->op.as()) { @@ -139,11 +138,32 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } return nullptr; }; + auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) { + if (const auto* op_node = get_op_node(op)) { + if (op_node->name == op_type) { + return true; + } + } + return false; + }; + auto is_expr_op = [](const Expr& expr, std::string op_type) { + if (const auto* call_node = expr.as()) { + if (const auto* op_node = call_node->op.as()) { + if (op_node->name == op_type) { + return true; + } + } + } + return false; + }; + // logic + auto watermark = matched_nodes_.size(); if (const auto* call_node = expr.as()) { auto matches_op = VisitDFPattern(op->op, call_node->op); if (matches_op) { auto watermark2 = matched_nodes_.size(); + auto match_args = [this, &watermark2](const Array pattern_args, const Array expr_args) { bool matches = true; size_t i = 0; @@ -161,10 +181,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex return matches; }; + // Standard case if (match_args(op->args, call_node->args)) { return true; } - if (auto* op_node = get_op_node(op)) { + // Commutative Matching + if (const OpNode* op_node = get_op_node(op)) { if ((op_node->name == "add") || (op_node->name == "multiply")) { if (match_args(reverse(op->args), call_node->args)) { return true; @@ -173,111 +195,47 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } } else { ClearMap(watermark); - // TODO(mbrookhart): This is nasty. Find a cleaner way to do this - if (const OpNode* op_node = get_op_node(op)) { - if (op_node->name == "divide") { - if (auto* arg_node = op->args[0].as()) { - if (const OpNode* arg_op = get_op_node(arg_node)) { - if (arg_op->name == "multiply") { - auto associate_div_mul = [this, &op, &arg_node, &expr, &watermark]() { - auto div1 = CallPatternNode::make(op->op, {arg_node->args[1], op->args[1]}, - op->attrs, op->type_args); - auto mul1 = CallPatternNode::make(arg_node->op, {arg_node->args[0], div1}, - arg_node->attrs, arg_node->type_args); - auto div2 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[1]}, - op->attrs, op->type_args); - auto mul2 = CallPatternNode::make(arg_node->op, {arg_node->args[1], div2}, - arg_node->attrs, arg_node->type_args); - auto out = VisitDFPattern(mul1, expr); - if (!out) { + // associate divide/multiply + if (is_pattern_op(op, "divide")) { + if (const auto* arg_node = op->args[0].as()) { + if (is_pattern_op(arg_node, "multiply")) { + if (is_expr_op(expr, "multiply")) { + if (is_expr_op(call_node->args[0], "divide") || + is_expr_op(call_node->args[1], "divide")) { + bool out = false; + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]}, + op->attrs, op->type_args); + auto mul = + CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}, + arg_node->attrs, arg_node->type_args); + out = VisitDFPattern(mul, expr); + if (out) { + return out; + } else { ClearMap(watermark); - out = VisitDFPattern(mul2, expr); - } - return out; - }; - - if (const OpNode* expr_op_node = call_node->op.as()) { - if (expr_op_node->name == "multiply") { - if (auto* input_call_node = call_node->args[0].as()) { - if (const OpNode* input_op_node = input_call_node->op.as()) { - if (input_op_node->name == "divide") { - return associate_div_mul(); - } - } - } - if (auto* input_call_node = call_node->args[1].as()) { - if (const OpNode* input_op_node = input_call_node->op.as()) { - if (input_op_node->name == "divide") { - return associate_div_mul(); - } - } - } } } + return out; } } } - } else if (op_node->name == "multiply") { - if (auto* arg_node = op->args[0].as()) { - if (const OpNode* arg_op = get_op_node(arg_node)) { - if (arg_op->name == "divide") { - auto associate_mul_div = [this, &op, &arg_node, &expr]() { - auto mul1 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[1]}, - op->attrs, op->type_args); - auto div1 = CallPatternNode::make(arg_node->op, {mul1, arg_node->args[1]}, - arg_node->attrs, arg_node->type_args); - return VisitDFPattern(div1, expr); - }; - - if (const OpNode* expr_op_node = call_node->op.as()) { - if (expr_op_node->name == "divide") { - if (auto* input_call_node = call_node->args[0].as()) { - if (const OpNode* input_op_node = input_call_node->op.as()) { - if (input_op_node->name == "multiply") { - return associate_mul_div(); - } - } - } - if (auto* input_call_node = call_node->args[1].as()) { - if (const OpNode* input_op_node = input_call_node->op.as()) { - if (input_op_node->name == "multiply") { - return associate_mul_div(); - } - } - } - } - } - } - } - } - if (auto* arg_node = op->args[1].as()) { - if (const OpNode* arg_op = get_op_node(arg_node)) { - if (arg_op->name == "divide") { - auto associate_mul_div = [this, &op, &arg_node, &expr]() { - auto mul1 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[0]}, - op->attrs, op->type_args); - auto div1 = CallPatternNode::make(arg_node->op, {mul1, arg_node->args[1]}, - arg_node->attrs, arg_node->type_args); - return VisitDFPattern(div1, expr); - }; - - if (const OpNode* expr_op_node = call_node->op.as()) { - if (expr_op_node->name == "divide") { - if (auto* input_call_node = call_node->args[0].as()) { - if (const OpNode* input_op_node = input_call_node->op.as()) { - if (input_op_node->name == "multiply") { - return associate_mul_div(); - } - } - } - if (auto* input_call_node = call_node->args[1].as()) { - if (const OpNode* input_op_node = input_call_node->op.as()) { - if (input_op_node->name == "multiply") { - return associate_mul_div(); - } - } - } - } + } + } + if (is_pattern_op(op, "multiply")) { + // associate multiply/divide + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + if (auto* arg_node = op->args[arg_id].as()) { + if (is_pattern_op(arg_node, "divide")) { + if (is_expr_op(expr, "divide")) { + if (is_expr_op(call_node->args[0], "multiply") || + is_expr_op(call_node->args[1], "multiply")) { + auto mul = + CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, + op->attrs, op->type_args); + auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]}, + arg_node->attrs, arg_node->type_args); + return VisitDFPattern(div, expr); } } } diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 6bba523b1a77..fd712680ff63 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -379,14 +379,13 @@ def test_fuse_batchnorm_commutation(): out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) - # associate multiply/divide - BN = (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) * gamma + beta + # associate divide/multiply + BN = (gamma * (x - mean)) /relay.op.sqrt(var + relay.const(1e-5)) + beta out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) - # associate divide/multiply - BN_pattern = wildcard() * ((wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard())) + wildcard() - BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + # associate multiply/divide + BN = gamma * ((x - mean)/relay.op.sqrt(var + relay.const(1e-5))) + beta out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) @@ -412,5 +411,5 @@ def test_fuse_batchnorm_commutation(): #test_no_fuse_batchnorm() #test_fuse_double_batchnorm() #test_partial_fuse_double_batchnorm() - #test_fuse_batchnorm_commutation() - test_match_dominator() + test_fuse_batchnorm_commutation() + #test_match_dominator() From 887dd7035b9b8fab5007891abf69f43b04143083 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 10 Apr 2020 16:37:21 -0700 Subject: [PATCH 15/46] add algebraic simplifier --- src/relay/ir/dataflow_matcher.cc | 2 +- tests/python/relay/test_df_pattern.py | 116 ++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index c6c166b5711b..c46871b0df8c 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -280,7 +280,7 @@ bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Exp return false; } bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { - return op->expr == expr; + return StructuralEqual()(op->expr, expr); } bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { bool matches = false; diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index fd712680ff63..10da8b587d24 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -247,6 +247,43 @@ def test_match_dominator(): # Check assert diamond.match(out) + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + relu = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + +def test_not_match_dominator(): + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(out) + def test_rewrite(): x = relay.var('x') y = relay.var('y') @@ -389,6 +426,85 @@ def test_fuse_batchnorm_commutation(): out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) +def algebraic_simplify(expr): + pattern_callbacks = [] + + def elwise_zero_callback(pre, post): + if (tvm.ir.structural_equal(post.args[0], relay.const(0)) | + tvm.ir.structural_equal(post.args[0], relay.const(0.0))): + return post.args[1] + else: + return post.args[0] + + def elwise_one_callback(pre, post): + if (tvm.ir.structural_equal(post.args[0], relay.const(1)) | + tvm.ir.structural_equal(post.args[0], relay.const(1.0))): + return post.args[1] + else: + return post.args[0] + + def return_zero_callback(pre, post): + if (tvm.ir.structural_equal(post.args[0], relay.const(0)) | + tvm.ir.structural_equal(post.args[0], relay.const(0.0))): + return post.args[0] + else: + return post.args[1] + + zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0))) + one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0))) + add_pattern = wildcard() + zero + pattern_callbacks.append(DFPatternCallback(add_pattern, elwise_zero_callback)) + + sub_pattern = wildcard() - zero + pattern_callbacks.append(DFPatternCallback(sub_pattern, elwise_zero_callback)) + + mul_pattern = wildcard() * one + pattern_callbacks.append(DFPatternCallback(mul_pattern, elwise_one_callback)) + + mul_zero_pattern = wildcard() * zero + pattern_callbacks.append(DFPatternCallback(mul_zero_pattern, return_zero_callback)) + + div_pattern = wildcard() / one + pattern_callbacks.append(DFPatternCallback(div_pattern, elwise_one_callback)) + + zero_div_pattern = zero / wildcard() + pattern_callbacks.append(DFPatternCallback(zero_div_pattern, return_zero_callback)) + + return rewrite(pattern_callbacks, expr); + +def test_algebraic_simplify(): + x = relay.Var('x') + y = relay.Var('y') + + print(x + relay.const(0)) + + one = relay.const(1) + zero = relay.const(0) + onef = relay.const(1.0) + zerof = relay.const(0.0) + + assert algebraic_simplify(x + zero) == x + assert algebraic_simplify(x + zerof) == x + assert algebraic_simplify(zero + x) == x + assert algebraic_simplify(zerof + x) == x + + assert algebraic_simplify(x - zero) == x + assert algebraic_simplify(x - zerof) == x + + assert algebraic_simplify(x * one) == x + assert algebraic_simplify(x * onef) == x + assert algebraic_simplify(one * x) == x + assert algebraic_simplify(onef * x) == x + assert algebraic_simplify(x * zero) == zero + assert algebraic_simplify(x * zerof) == zerof + + assert algebraic_simplify(x / one) == x + assert algebraic_simplify(x / onef) == x + assert algebraic_simplify(zero / x) == zero + assert algebraic_simplify(zerof / x) == zerof + + assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y) + if __name__ == "__main__": #test_match_op() #test_no_match_op() From 05c5da2b5d6be055057b9229f1afb7c042ed5d4f Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 13 Apr 2020 14:20:05 -0700 Subject: [PATCH 16/46] add more tests --- include/tvm/relay/dataflow_functor.h | 5 ++++- python/tvm/relay/df_pattern/__init__.py | 4 ++-- src/relay/ir/dataflow_matcher.cc | 14 ++++++-------- tests/python/relay/test_df_pattern.py | 20 +++++++++++++++++++- 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index e2457ad6ef43..fb2af5546759 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -25,8 +25,11 @@ #define TVM_RELAY_DATAFLOW_FUNCTOR_H_ #include +#include #include +#include #include +#include namespace tvm { namespace relay { @@ -159,7 +162,7 @@ class IndexedGraph { public: struct Node; struct Edge { - Edge(const std::shared_ptr& sink) : sink_(sink) {} + explicit Edge(const std::shared_ptr& sink) : sink_(sink) {} std::shared_ptr sink_; }; struct Node { diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 30a5450ae947..f183f67e7f2e 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -231,7 +231,7 @@ class DominatorPattern(DFPattern): Parameters ---------- parent: tvm.relay.df_pattern.DFPattern - The root of domination + The root of domination path: tvm.relay.df_pattern.DFPattern The fuzzy path pattern between parent and child child: tvm.relay.df_pattern.DFPattern @@ -240,7 +240,7 @@ class DominatorPattern(DFPattern): def __init__(self, parent, path, child): self.__init_handle_by_constructor__( ffi.DominatorPattern, parent, path, child) - + class DFPatternCallback(Object): def __init__(self, pattern, callback): self.__init_handle_by_constructor__( diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index c46871b0df8c..e0cb26d0d8c8 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -23,8 +23,8 @@ */ #include -#include #include +#include #include namespace tvm { @@ -164,7 +164,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex if (matches_op) { auto watermark2 = matched_nodes_.size(); - auto match_args = [this, &watermark2](const Array pattern_args, const Array expr_args) { + auto match_args = [this, &watermark2](const Array pattern_args, + const Array expr_args) { bool matches = true; size_t i = 0; if (pattern_args.size() == expr_args.size()) { @@ -347,14 +348,12 @@ DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc func TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") -.set_body_typed(DFPatternCallbackNode::make); + .set_body_typed(DFPatternCallbackNode::make); class PatternRewriter : protected MixedModeMutator { public: PatternRewriter(const Array& callbacks) : callbacks_(callbacks) {} - Expr Rewrite(const Expr& pre) { - return this->VisitExpr(pre); - } + Expr Rewrite(const Expr& pre) { return this->VisitExpr(pre); } protected: Expr DispatchVisitExpr(const Expr& pre) override { @@ -377,8 +376,7 @@ Expr RewritePatterns(Array callbacks, Expr expr) { return PatternRewriter(callbacks).Rewrite(expr); } -TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite") -.set_body_typed(RewritePatterns); +TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 10da8b587d24..40dd4ef27463 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -284,6 +284,23 @@ def test_not_match_dominator(): # Check assert not diamond.match(out) + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(inp) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(out) + def test_rewrite(): x = relay.var('x') y = relay.var('y') @@ -527,5 +544,6 @@ def test_algebraic_simplify(): #test_no_fuse_batchnorm() #test_fuse_double_batchnorm() #test_partial_fuse_double_batchnorm() - test_fuse_batchnorm_commutation() + #test_fuse_batchnorm_commutation() #test_match_dominator() + test_not_match_dominator() From 686c43cd01b0307d92ef5daa6b32e9736e5b9369 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 14 Apr 2020 09:29:31 -0700 Subject: [PATCH 17/46] document python API --- python/tvm/relay/df_pattern/__init__.py | 273 +++++++++++++++++++----- 1 file changed, 221 insertions(+), 52 deletions(-) diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index f183f67e7f2e..d51b3bc5af2d 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -21,6 +21,8 @@ from ... import _ffi as tvm_ffi from ..op import get from . import _ffi as ffi +from tvm.relay import Expr + def register_df_node(type_key=None): """Register a Relay node type. @@ -35,12 +37,11 @@ def register_df_node(type_key=None): "relay.df_pattern." + type_key.__name__)(type_key) return tvm_ffi.register_object(type_key) -class DFPattern(Node): - """Base class of all primitive expressions. - PrimExpr is used in the low-level code - optimizations and integer analysis. +class DFPattern(Node): + """Base class of all Patterns. """ + def __call__(self, *args): return CallPattern(self, list(args)) @@ -59,13 +60,154 @@ def __mul__(self, other): def __truediv__(self, other): return is_op("divide")(self, other) - def has_attr(self, attr_name, attr_value): - attrs = make_node("DictAttrs", **{attr_name:attr_value}) + def has_attr(self, attr_name: str, attr_value): + """ + Add an attribute constraint to this pattern + + Parameters + ---------- + attr_name: str + The name of the attribute to match + attr_value: Any + The value of the attribute to match + """ + attrs = make_node("DictAttrs", **{attr_name: attr_value}) return AttrPattern(self, attrs) - def match(self, expr): + def has_type(self, ttype): + """ + Add a type constraint to this pattern + + Parameters + ---------- + ttype: tvm.relay.Type + The type to match + """ + return has_type(ttype, self) + + def match(self, expr: Expr) -> bool: + """ + Match this pattern to an expression + + Parameters + ---------- + expr : tvm.relay.Expr + The expression to match. + """ return match(self, expr) + def dominates(self, parent, path=None): + """ + Create a dominator for this partern + + Parameters + ---------- + parent: tvm.relay.df_pattern.DFPattern + The parent pattern this pattern dominates. + path: tvm.relay.df_pattern.DFPattern + The fuzzy path pattern. + """ + if path == None: + path = wildcard() + return DominatorPattern(parent, path, self) + + +def is_input(name: str = "") -> DFPattern: + """ + Syntatic sugar for creating an optionally named VarPattern + + Parameters + ---------- + name: str + The name of the input pattern to match + """ + return VarPattern(name) + + +def is_op(op_name: str) -> DFPattern: + """ + Syntatic sugar for creating an operator ExprPattern + + Parameters + ---------- + op_name: String + The name of the relay op + """ + op = get(op_name) + return ExprPattern(op) + + +def wildcard() -> DFPattern: + """ + Syntatic sugar for creating a WildcardPattern + """ + return WildcardPattern() + + +def has_type(ttype, pattern: DFPattern = None) -> DFPattern: + """ + Syntatic sugar for creating a TypePattern + + Parameters + ---------- + pattern: tvm.relay.df_pattern.DFPattern + The pattern that needs type annotation + + ttype: tvm.relay.Type + The type to match + """ + if pattern is None: + pattern = wildcard() + return TypePattern(pattern, ttype) + + +def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern: + """ + Syntatic sugar for creating an AttrPattern + + Parameters + ---------- + pattern: tvm.relay.df_pattern.DFPattern + The input pattern. + + attrs: tvm.Attrs + The attributes to match + """ + if pattern is None: + pattern = wildcard() + return pattern.has_attr(attr_name, attr_value) + + +def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern: + """ + Syntatic sugar for creating an Dominator pattern + + Parameters + ---------- + parent: tvm.relay.df_pattern.DFPattern + The parent pattern. + path: tvm.relay.df_pattern.DFPattern + The fuzzy path pattern. + child: tvm.relay.df_pattern.DFPattern + The child pattern. + """ + return DominatorPattern(parent, path, child) + + +def match(pattern: DFPattern, expr: Expr) -> bool: + """ + Match a pattern to an expression + + Parameters + ---------- + pattern: tvm.relay.df_pattern.DFPattern + The input pattern. + expr : tvm.relay.Expr + The expression to match. + """ + return ffi.match(pattern, expr) + + @register_df_node class ExprPattern(DFPattern): """A pattern which matches a constant expression. @@ -75,9 +217,11 @@ class ExprPattern(DFPattern): expr : tvm.relay.Expr The expression to match. """ - def __init__(self, expr): + + def __init__(self, expr: Expr): self.__init_handle_by_constructor__(ffi.ExprPattern, expr) + @register_df_node class VarPattern(DFPattern): """A local variable in Relay. @@ -95,16 +239,11 @@ class VarPattern(DFPattern): type_annotation: tvm.relay.Type, optional The type annotation on the variable. """ - def __init__(self, name_hint, type_annotation=None): + + def __init__(self, name_hint: str, type_annotation=None): self.__init_handle_by_constructor__( ffi.VarPattern, name_hint, type_annotation) -# @property -# def name_hint(self): -# """Get name hint of the current var.""" -# name = self.name -# return name - @register_df_node class CallPattern(DFPattern): @@ -125,12 +264,14 @@ class CallPattern(DFPattern): The additional type arguments, this is only used in advanced usecase of template functions. """ + def __init__(self, op, args, attrs=None, type_args=None): if not type_args: type_args = [] self.__init_handle_by_constructor__( ffi.CallPattern, op, args, attrs, type_args) + @register_df_node class TuplePattern(DFPattern): """A patern matching a Relay Tuple. @@ -140,6 +281,7 @@ class TuplePattern(DFPattern): fields : List[tvm.relay.df_pattern.DFPattern] The fields in the tuple. """ + def __init__(self, fields): self.__init_handle_by_constructor__(ffi.TuplePattern, fields) @@ -154,6 +296,7 @@ def __len__(self): def astype(self, _): raise TypeError("astype cannot be used on TuplePattern") + @register_df_node class TupleGetItemPattern(DFPattern): """Get index-th item from a TuplePattern. @@ -166,10 +309,12 @@ class TupleGetItemPattern(DFPattern): index: int The index. """ - def __init__(self, tuple_value, index): + + def __init__(self, tuple_value: DFPattern, index): self.__init_handle_by_constructor__( ffi.TupleGetItemPattern, tuple_value, index) + @register_df_node class AltPattern(DFPattern): """Create a Pattern that can match one of two conditions @@ -181,17 +326,21 @@ class AltPattern(DFPattern): right: tvm.relay.df_pattern.DFPattern One possible matching Pattern """ - def __init__(self, tuple_value, index): + + def __init__(self, left: DFPattern, right: DFPattern): self.__init_handle_by_constructor__( - ffi.AltPattern, tuple_value, index) + ffi.AltPattern, left, right) + @register_df_node class WildcardPattern(DFPattern): """A pattern which matches anything. """ + def __init__(self): self.__init_handle_by_constructor__(ffi.WildcardPattern) + @register_df_node class TypePattern(DFPattern): """Get index-th item from a TuplePattern. @@ -199,80 +348,100 @@ class TypePattern(DFPattern): Parameters ---------- pattern: tvm.relay.df_pattern.DFPattern - The input tuple expression. + The input pattern that needs type annotation ttype: tvm.relay.Type The type to match """ - def __init__(self, pattern, ttype): + + def __init__(self, pattern: DFPattern, ttype): self.__init_handle_by_constructor__( ffi.TypePattern, pattern, ttype) + @register_df_node class AttrPattern(DFPattern): - """Get index-th item from a TuplePattern. + """Get match an expression with a certain attributes. + Currently only supports Op Attributes, not call Attributes Parameters ---------- pattern: tvm.relay.df_pattern.DFPattern - The input tuple expression. + The input pattern. attrs: tvm.Attrs The attributes to match """ - def __init__(self, pattern, attrs): + + def __init__(self, pattern: DFPattern, attrs): self.__init_handle_by_constructor__( ffi.AttrPattern, pattern, attrs) + @register_df_node class DominatorPattern(DFPattern): - """Get index-th item from a TuplePattern. + """Match a domination graph. Parameters ---------- parent: tvm.relay.df_pattern.DFPattern - The root of domination + The parent, i.e., the single node which produces something, + later aggregated by the child path: tvm.relay.df_pattern.DFPattern - The fuzzy path pattern between parent and child + The fuzzy path pattern between parent and child, + typically matches elementwise ops child: tvm.relay.df_pattern.DFPattern - The last node in the domination + The last node in the domination which is the end user + for all nodes in the path and the parent """ - def __init__(self, parent, path, child): + + def __init__(self, parent: DFPattern, path: DFPattern, child: DFPattern): self.__init_handle_by_constructor__( ffi.DominatorPattern, parent, path, child) + class DFPatternCallback(Object): - def __init__(self, pattern, callback): - self.__init_handle_by_constructor__( - ffi.DFPatternCallback, pattern, callback) + """A Callback for Pattern Rewriting -def is_input(name="") -> DFPattern: - return VarPattern(name) + When rewrite is called on this DFPatternCallback, the backend will find matches for the + pattern, call the callback function, and replace the matched expression with whatever + the callback returns. -def is_op(op_name: str) -> DFPattern: - op = get(op_name) - return ExprPattern(op) + Parameters + ---------- + pattern: tvm.relay.df_pattern.DFPattern + The Pattern to match + callback: PackedFunc + The callback function. + """ -def wildcard() -> DFPattern: - return WildcardPattern() + def __init__(self, pattern: DFPattern, callback): + self.__init_handle_by_constructor__( + ffi.DFPatternCallback, pattern, callback) -def has_type(ttype, pattern=None): - if pattern is None: - pattern = wildcard() - return TypePattern(pattern, ttype) + def rewrite(self, expr: Expr) -> Expr: + """ + Rewrite expression with this callback -def has_attr(attr_name, attr_value, pattern=None): - if pattern is None: - pattern = wildcard() - return patter.has_attr(attr_name, attr_value) + Parameters + ---------- + expr : tvm.relay.Expr + The expression to rewrite. + """ + return rewrite(self, expr) -def match(pattern, expr): - return ffi.match(pattern, expr) -def rewrite(callbacks, expr): +def rewrite(callbacks, expr: Expr) -> Expr: + """ + Rewrite expression with the given callbacks + + Parameters + ---------- + callbacks: tvm.relay.df_pattern.DFPatternCallback + The input callback or list of callbacks. + expr : tvm.relay.Expr + The expression to rewrite. + """ if isinstance(callbacks, DFPatternCallback): callbacks = [callbacks] return ffi.rewrite(callbacks, expr) - -def dominates(parent, path, child): - return DominatorPattern(parent, path, child) From 332446138f4f2324391be9ae1fb34dde0001edb3 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 14 Apr 2020 09:31:00 -0700 Subject: [PATCH 18/46] remove the unused dataflow mutator --- include/tvm/relay/dataflow_functor.h | 19 ----- src/relay/ir/dataflow_functor.cc | 111 --------------------------- 2 files changed, 130 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index fb2af5546759..8c3eda807f46 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -138,25 +138,6 @@ class DFPatternVisitor : public DFPatternFunctor { std::unordered_set visited_; }; -class DFPatternMutator : public DFPatternFunctor { - public: - virtual DFPattern Mutate(const DFPattern& pattern); - DFPattern VisitDFPattern(const DFPattern& pattern) override; - DFPattern VisitDFPattern_(const AltPatternNode* op) override; - DFPattern VisitDFPattern_(const AttrPatternNode* op) override; - DFPattern VisitDFPattern_(const CallPatternNode* op) override; - DFPattern VisitDFPattern_(const DominatorPatternNode* op) override; - DFPattern VisitDFPattern_(const ExprPatternNode* op) override; - DFPattern VisitDFPattern_(const TupleGetItemPatternNode* op) override; - DFPattern VisitDFPattern_(const TuplePatternNode* op) override; - DFPattern VisitDFPattern_(const TypePatternNode* op) override; - DFPattern VisitDFPattern_(const VarPatternNode* op) override; - DFPattern VisitDFPattern_(const WildcardPatternNode* op) override; - - protected: - std::unordered_map memo_; -}; - template class IndexedGraph { public: diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_functor.cc index b428c12bdb0c..3d15995b0314 100644 --- a/src/relay/ir/dataflow_functor.cc +++ b/src/relay/ir/dataflow_functor.cc @@ -76,117 +76,6 @@ void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} -// DFPatternMutator - -DFPattern DFPatternMutator::Mutate(const DFPattern& pattern) { return VisitDFPattern(pattern); } - -DFPattern DFPatternMutator::VisitDFPattern(const DFPattern& pattern) { - auto it = this->memo_.find(pattern); - if (it != this->memo_.end()) { - return it->second; - } else { - auto new_pattern = DFPatternFunctor::VisitDFPattern(pattern); - memo_[pattern] = new_pattern; - return new_pattern; - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const AltPatternNode* op) { - auto new_left = Mutate(op->left); - auto new_right = Mutate(op->right); - - if (new_left.same_as(op->left) && new_right.same_as(op->right)) { - return GetRef(op); - } else { - return AltPatternNode::make(new_left, new_right); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const AttrPatternNode* op) { - auto new_pattern = Mutate(op->pattern); - if (new_pattern.same_as(op->pattern)) { - return GetRef(op); - } else { - return AttrPatternNode::make(new_pattern, op->attrs); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const CallPatternNode* op) { - auto new_op = Mutate(op->op); - bool unchanged = op->op.same_as(new_op); - tvm::Array call_args; - for (auto arg : op->args) { - auto new_arg = Mutate(arg); - call_args.push_back(new_arg); - unchanged &= arg.same_as(new_arg); - } - if (unchanged) { - return GetRef(op); - } else { - return CallPatternNode::make(new_op, call_args, op->attrs, op->type_args); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const DominatorPatternNode* op) { - auto new_parent = Mutate(op->parent); - auto new_path = Mutate(op->path); - auto new_child = Mutate(op->child); - if (op->parent.same_as(new_child) && op->parent.same_as(new_child) && - op->parent.same_as(new_child)) { - return GetRef(op); - } else { - return DominatorPatternNode::make(new_parent, new_path, new_child); - } -} - - -DFPattern DFPatternMutator::VisitDFPattern_(const ExprPatternNode* op) { - return GetRef(op); -} - -DFPattern DFPatternMutator::VisitDFPattern_(const TupleGetItemPatternNode* op) { - auto new_tuple = Mutate(op->tuple); - if (new_tuple.same_as(op->tuple)) { - return GetRef(op); - } else { - return TupleGetItemPatternNode::make(op->tuple, op->index); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const TuplePatternNode* op) { - bool unchanged = true; - tvm::Array fields; - for (auto field : op->fields) { - auto new_field = Mutate(field); - fields.push_back(new_field); - unchanged &= field.same_as(new_field); - } - if (unchanged) { - return GetRef(op); - } else { - return TuplePatternNode::make(fields); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const TypePatternNode* op) { - auto new_pattern = Mutate(op->pattern); - if (new_pattern.same_as(op->pattern)) { - return GetRef(op); - } else { - return TypePatternNode::make(new_pattern, op->type); - } -} - -DFPattern DFPatternMutator::VisitDFPattern_(const VarPatternNode* op) { - return GetRef(op); -} - -DFPattern DFPatternMutator::VisitDFPattern_(const WildcardPatternNode* op) { - return GetRef(op); -} - - - // IndexedGraph IndexedGraph CreateIndexedGraph(const Expr& expr) { From 53ce92679dee2cb96e1dc56d44ba52cebe6ea749 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 14 Apr 2020 10:58:40 -0700 Subject: [PATCH 19/46] comment dataflow functor, fix lint --- include/tvm/relay/dataflow_functor.h | 50 ++++++++++++++++++++----- include/tvm/relay/dataflow_matcher.h | 2 + include/tvm/relay/dataflow_pattern.h | 19 ---------- python/tvm/relay/df_pattern/__init__.py | 4 +- src/relay/ir/dataflow_functor.cc | 39 ++++++++++++------- src/relay/ir/dataflow_matcher.cc | 2 +- 6 files changed, 72 insertions(+), 44 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index 8c3eda807f46..f48928885950 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -120,6 +120,12 @@ class DFPatternFunctor { } }; +/*! + * \brief A simple visitor wrapper around DFPatternFunctor. + * Recursively visit the content. + * + * DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once. + */ class DFPatternVisitor : public DFPatternFunctor { public: void VisitDFPattern(const DFPattern& pattern) override; @@ -135,28 +141,47 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const WildcardPatternNode* op) override; protected: + // set of already-visited nodes std::unordered_set visited_; }; +/*! + * \brief A Wrapper around a templated graph type + * Holds a forward-backward indexed representation of the graph and a dominator tree representation + * of the graph + * + * Class is Templated and the implementaiton is in the header file so we can analyis both DFPattern + * and Expr with the same infrastructure. + * + * IndexedGraph should be instantiated thorught the CreateIndexedGraph utilities. + */ template class IndexedGraph { public: - struct Node; - struct Edge { - explicit Edge(const std::shared_ptr& sink) : sink_(sink) {} - std::shared_ptr sink_; - }; + /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */ struct Node { + /*! \brief Node Constructor + * \param ref The input graph node + * \param index The index of the node in toplogoical order + */ Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + + /*! \brief The input node */ const T ref_; + /*! \brief The topological order index */ const size_t index_; + /*! \brief A boolean to determine if this node is external to the graph */ bool is_external_ = false; - std::vector> outputs_; + /*! \brief The forward outputs/users of the node */ + std::vector> outputs_; + /*! \brief The depth of the node in the dominator tree */ size_t depth_; + /*! \brief The dominator parent/final user of the outputs of this node */ std::shared_ptr dominator_parent_; }; + /*! \brief Construct the domination create of the index graph */ void PostDom() { for (size_t i = topological_order_.size(); i != 0; --i) { size_t index = i - 1; @@ -171,20 +196,25 @@ class IndexedGraph { } } } + /*! \brief Map of input nodes to IndexedGraph Nodes */ std::unordered_map, ObjectHash, ObjectEqual> node_map_; + /*! \brief Topological IndexedGraph Nodes */ std::vector> topological_order_; protected: - std::shared_ptr LeastCommonAncestor(const std::vector>& outputs) { + /*! \brief Find the least common ancestor of all outputs of a node */ + std::shared_ptr LeastCommonAncestor(const std::vector>& outputs) { if (outputs.size() == 0) { return nullptr; } - auto parent = outputs.at(0)->sink_; + auto parent = outputs.at(0); for (size_t i = 1; i < outputs.size(); ++i) { - parent = LeastCommonAncestor(parent, outputs.at(i)->sink_); + parent = LeastCommonAncestor(parent, outputs.at(i)); } return parent; } + + /*! \brief Find the least common ancestor of two nodes */ std::shared_ptr LeastCommonAncestor(std::shared_ptr lhs, std::shared_ptr rhs) { if (lhs == nullptr || rhs == nullptr) { return nullptr; @@ -203,7 +233,9 @@ class IndexedGraph { } }; +/*! \brief Create an Indexed Graph based on an Expr */ IndexedGraph CreateIndexedGraph(const Expr& expr); +/*! \brief Create an Indexed Graph based on an DFPattern */ IndexedGraph CreateIndexedGraph(const DFPattern& pattern); } // namespace relay diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index b81e5996b8db..9bf7e15b0933 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -39,7 +39,9 @@ class DFPatternCallback; */ class DFPatternCallbackNode : public Object { public: + /*! \brief Pattern this callback matches */ DFPattern pattern_; + /*! \brief Function to call when finding a matched expression */ PackedFunc function_; void VisitAttrs(tvm::AttrVisitor* v) {} diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 5845778924e6..52ed9fe5f4ce 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -276,25 +276,6 @@ class WildcardPattern : public DFPattern { TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); }; -/*! - * \brief Null Pattern. - */ -class NullPatternNode : public DFPatternNode { - public: - void VisitAttrs(tvm::AttrVisitor* v) {} - - static constexpr const char* _type_key = "relay.df_pattern.NullPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(NullPatternNode, DFPatternNode); -}; - -/*! - * \brief A pattern which matches anything. - */ -class NullPattern : public DFPattern { - public: - TVM_DEFINE_OBJECT_REF_METHODS(NullPattern, DFPattern, NullPatternNode); -}; - class TypePattern; /*! * \brief Pattern for Types. diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index d51b3bc5af2d..edae2d6c32bc 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. """The Relay Pattern Language and tooling.""" +from tvm.relay import Expr from ...ir.base import Node from ...ir import make_node from ...runtime import Object from ... import _ffi as tvm_ffi from ..op import get from . import _ffi as ffi -from tvm.relay import Expr def register_df_node(type_key=None): @@ -107,7 +107,7 @@ def dominates(self, parent, path=None): path: tvm.relay.df_pattern.DFPattern The fuzzy path pattern. """ - if path == None: + if path is None: path = wildcard() return DominatorPattern(parent, path, self) diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_functor.cc index 3d15995b0314..e80d888e12ba 100644 --- a/src/relay/ir/dataflow_functor.cc +++ b/src/relay/ir/dataflow_functor.cc @@ -23,8 +23,8 @@ */ #include -#include #include +#include #include namespace tvm { @@ -80,6 +80,7 @@ void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} IndexedGraph CreateIndexedGraph(const Expr& expr) { using NodePtr = std::shared_ptr::Node>; + /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ class Creator : public MixedModeVisitor { public: IndexedGraph CreateGraph(const Expr& expr) { @@ -98,22 +99,30 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { IndexedGraph graph_; size_t index_ = 0; }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree + * analysis. + * + * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined + * topological order instead of recursing. + */ class Annotator : public ExprFunctor { public: Annotator(const IndexedGraph& graph) : graph_(graph) {} IndexedGraph Annotate() { + // Visit all of the nodes in topological order to get forward outputs for (const auto& node : graph_.topological_order_) { ExprFunctor::VisitExpr(node->ref_, nullptr); } + // do the dominator analysis graph_.PostDom(); return std::move(graph_); } + /*! Default visitation pushes the parent to the child's ouputs */ void VisitExpr(const Expr& expr, NodePtr parent) override { auto current = graph_.node_map_[expr]; if (parent) { - auto edge = std::make_shared::Edge>(parent); - current->outputs_.push_back(edge); + current->outputs_.push_back(parent); } } @@ -214,6 +223,7 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { using NodePtr = std::shared_ptr::Node>; + /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ class Creator : public DFPatternVisitor { public: IndexedGraph CreateGraph(const DFPattern& pattern) { @@ -232,22 +242,30 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { IndexedGraph graph_; size_t index_ = 0; }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree + * analysis. + * + * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined + * topological order instead of recursing. + */ class Annotator : public DFPatternFunctor { public: Annotator(const IndexedGraph& graph) : graph_(graph) {} IndexedGraph Annotate() { + // Visit all of the nodes in topological order to get forward outputs for (const auto& node : graph_.topological_order_) { DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); } graph_.PostDom(); + // do the dominator analysis return std::move(graph_); } + /*! Default visitation pushes the parent to the child's ouputs */ void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { auto current = graph_.node_map_[pattern]; if (parent) { - auto edge = std::make_shared::Edge>(parent); - current->outputs_.push_back(edge); + current->outputs_.push_back(parent); } } @@ -268,8 +286,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); } } - void VisitDFPattern_(const DominatorPatternNode* op, - NodePtr parent) override { + void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); @@ -277,8 +294,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} - void VisitDFPattern_(const TupleGetItemPatternNode* op, - NodePtr parent) override { + void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); } @@ -294,13 +310,10 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} - void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override { - } + void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} }; return Annotator(Creator().CreateGraph(pattern)).Annotate(); } - - } // namespace relay } // namespace tvm diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index e0cb26d0d8c8..0ec56535e268 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -352,7 +352,7 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") class PatternRewriter : protected MixedModeMutator { public: - PatternRewriter(const Array& callbacks) : callbacks_(callbacks) {} + explicit PatternRewriter(const Array& callbacks) : callbacks_(callbacks) {} Expr Rewrite(const Expr& pre) { return this->VisitExpr(pre); } protected: From 88a67907693116da81d2675c1d8fb9a9e693242b Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 14 Apr 2020 13:17:09 -0700 Subject: [PATCH 20/46] add rfc as doc --- docs/langref/index.rst | 1 + docs/langref/relay_pattern.rst | 143 +++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+) create mode 100644 docs/langref/relay_pattern.rst diff --git a/docs/langref/index.rst b/docs/langref/index.rst index 0d296118da26..dcea9fa50c3d 100644 --- a/docs/langref/index.rst +++ b/docs/langref/index.rst @@ -46,6 +46,7 @@ algebraic data types, and operators in Relay, respectively. relay_type relay_adt relay_op + relay_pattern Hybrid Script ------------- diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst new file mode 100644 index 000000000000..cf47e9ca51d9 --- /dev/null +++ b/docs/langref/relay_pattern.rst @@ -0,0 +1,143 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +========================= +Pattern Matching in Relay +========================= + +There are many places in TVM where we identify pure data-flow sub-graphs of the Relay program and attempt to transform them in some way example passes include fusion, quantization, external code generation, and device specific optimizations such as bitpacking, and layer slicing used by VTA. + +Many of these passes today require a lots of boring boilerplate code in order to implement as well as requiring users to think in terms of visitors and AST matching. Many of these transformations can easily be described in terms of graph rewrites. In order to build a rewriter or other advanced machinery we first need a language of patterns to describe what we can match. + +Such a language is not just useful for building a rewriter but also providing extension points for existing passes. For example the fusion pass could be parametrized by a set of fusion patterns which describes the capability of your hardware, and the quantization pass could take a set of patterns which describe which operators can be quantized on a given platform. + +In the backend world, we could use the same machinery to build a higher level API using bring your own code generation. This API takes set of patterns describing your hardware capabilities and an external compiler, providing a relatively smooth heterogeneous experience out of the box. + +Recently there has been lots of discussion on similar issues in the community, and we wanted to gather feedback and hopefully collaborate on a design that can benefit everyone working in this space. This RFC focuses on the pattern language with future applications to come later. + +Examples +======== + +There are quite a few properties that are worth matching of operators below we examine how to match tree properties, and expand on some use cases that are not fully explored in the prototype. The first example is a simple case where we want to match one operator with a single input OR another operator with a single input, see the below diagram for a graphical representation and corresponding code:: + + def test_match_op_or(): + is_add_or_sub = is_op('add') | is_op('subtract') + assert is_add_or_sub.match(relay.op.op.get("add")) + assert is_add_or_sub.match(relay.op.op.get("subtract")) + +The next example is a dense operation with any operator that is marked element-wise:: + + def test_no_match_attr(): + op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert not op_pat.match(relay.op.nn.dense(x, y)) + +The next example is matching a diamond with two inputs at the top of the diamond:: + + def test_match_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + +The final example we would like to match which is not yet implemented in the prototype is matching diamonds with a post-dominator relationship. Our plan is to embed dominator analysis as type of matching in the pattern language in order to allow for pattern matching with unknown topology. This is important because we want to able to use the language to describe fuse patterns, like elementwise operations followed by a conv2d:: + + def test_match_dom_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_elemwise, reduction) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + +Design +====== + +The pattern language proposed is designed to be a mirror of Relay's IR with additional support for common scenarios. The goal of the pattern language is to provide a regular-expression like capability for matching data-flow graphs and doing rewriting. + +The high level design is to introduce a language of patterns for now we propose the language as:: + + Pattern ::= expr + | * + | pattern(pattern1, ... patternN) + | has_type(pattern, type) + | has_attr(pattern, attr, attr_value) + | is_input(name) + | pattern1 `|` pattern2 + | dominates(parent_pattern, path_pattern, child_pattern) + +The above language then provides a matching interface with both can select sub-graphs as well as verify that the graph does match the pattern. + +Expression Pattern +****************** + +Match a literal expression. + +Wildcard +****************** + +Match any expression. + +Type Pattern +****************** + +Check that the expression matched by the nested pattern has a particular type. + +Attribute Pattern +****************** + +Check that the operator matched by the pattern has an attribute with a particular value. + +Input +****************** + +Check that the expression is an input, i.e has no parents and is a variable. + + +Alternate +****************** + +Either match the first pattern or the second pattern. + +Domination +****************** + +Match the parent pattern for the route node and then check that the child pattern holds for each child along the domination path. From 43fb8dc04a7b05c23411d071905c47336c841c90 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 14 Apr 2020 16:31:18 -0700 Subject: [PATCH 21/46] fix some edge cases with the dominator pattern --- include/tvm/relay/dataflow_functor.h | 13 +++-- src/relay/ir/dataflow_functor.cc | 4 +- src/relay/ir/dataflow_matcher.cc | 74 +++++++++++++++++++-------- tests/python/relay/test_df_pattern.py | 49 +++++++++++++----- 4 files changed, 99 insertions(+), 41 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index f48928885950..40a883d7f254 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -174,18 +174,20 @@ class IndexedGraph { /*! \brief A boolean to determine if this node is external to the graph */ bool is_external_ = false; /*! \brief The forward outputs/users of the node */ - std::vector> outputs_; + std::vector outputs_; /*! \brief The depth of the node in the dominator tree */ size_t depth_; /*! \brief The dominator parent/final user of the outputs of this node */ - std::shared_ptr dominator_parent_; + Node* dominator_parent_; + /*! \brief The nodes this node dominates */ + std::vector dominator_children_; }; /*! \brief Construct the domination create of the index graph */ void PostDom() { for (size_t i = topological_order_.size(); i != 0; --i) { size_t index = i - 1; - auto current = topological_order_[index]; + auto* current = topological_order_[index].get(); if (current->is_external_) { current->depth_ = 1; current->dominator_parent_ = nullptr; @@ -193,6 +195,7 @@ class IndexedGraph { auto parent = LeastCommonAncestor(current->outputs_); current->depth_ = parent ? parent->depth_ + 1 : 1; current->dominator_parent_ = parent; + parent->dominator_children_.push_back(current); } } } @@ -203,7 +206,7 @@ class IndexedGraph { protected: /*! \brief Find the least common ancestor of all outputs of a node */ - std::shared_ptr LeastCommonAncestor(const std::vector>& outputs) { + Node* LeastCommonAncestor(const std::vector& outputs) { if (outputs.size() == 0) { return nullptr; } @@ -215,7 +218,7 @@ class IndexedGraph { } /*! \brief Find the least common ancestor of two nodes */ - std::shared_ptr LeastCommonAncestor(std::shared_ptr lhs, std::shared_ptr rhs) { + Node* LeastCommonAncestor(Node* lhs, Node* rhs) { if (lhs == nullptr || rhs == nullptr) { return nullptr; } diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_functor.cc index e80d888e12ba..ffbb4b342756 100644 --- a/src/relay/ir/dataflow_functor.cc +++ b/src/relay/ir/dataflow_functor.cc @@ -122,7 +122,7 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { void VisitExpr(const Expr& expr, NodePtr parent) override { auto current = graph_.node_map_[expr]; if (parent) { - current->outputs_.push_back(parent); + current->outputs_.push_back(parent.get()); } } @@ -265,7 +265,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { auto current = graph_.node_map_[pattern]; if (parent) { - current->outputs_.push_back(parent); + current->outputs_.push_back(parent.get()); } } diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 0ec56535e268..251d3bfcaaa8 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -52,7 +52,7 @@ class DFPatternMatcher : public DFPatternFunctor memo_; std::vector matched_nodes_; -}; + }; bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { memo_.clear(); @@ -98,6 +98,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons switch (op_map[op].type_code()) { case kDLInt: if (auto* val = kv.second.as()) { + std::cout << op << " " << op_map[op].operator int64_t() << std::endl; matches = val->value == op_map[op].operator int64_t(); } break; @@ -249,32 +250,61 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { auto watermark = matched_nodes_.size(); + auto backup_memo = memo_; + auto backup_matched_nodes = matched_nodes_; + if (VisitDFPattern(op->child, expr)) { - bool matches = true; - std::unordered_set dominated_exprs; - auto child_graph = CreateIndexedGraph(op->child); - for (auto node : child_graph.topological_order_) { - if (node->ref_.as()) { - continue; - } - if (node->dominator_parent_ && node->dominator_parent_->ref_ == op->child) { - dominated_exprs.insert(memo_[node->ref_]); - } - } - ClearMap(watermark); + auto child_graph = CreateIndexedGraph(GetRef(op)); auto expr_graph = CreateIndexedGraph(expr); - for (auto node : expr_graph.topological_order_) { - if (node->dominator_parent_ && node->dominator_parent_->ref_ == expr) { - if (dominated_exprs.count(node->ref_) == 0) { - bool node_matches = VisitDFPattern(op->parent, node->ref_); - ClearMap(watermark); - matches = node_matches || VisitDFPattern(op->path, node->ref_); - ClearMap(watermark); - if (!matches) { - return false; + auto find_dominated = [&child_graph, this](const DFPattern& node) { + std::unordered_set dominated_exprs; + auto indexed_node = child_graph.node_map_[node]; + for (auto dominated : indexed_node->dominator_children_) { + if (dominated->ref_.as() || dominated->ref_.as()) { + continue; + } + dominated_exprs.insert(memo_[dominated->ref_]); + } + return dominated_exprs; + }; + std::function&)> + find_parent; + find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated, + &expr_graph, &find_parent]( + const Expr& expr, + const std::unordered_set& dominated_exprs) { + bool out = true; + for (auto node : expr_graph.node_map_[expr]->dominator_children_) { + if (out && dominated_exprs.count(node->ref_) == 0) { + if (VisitDFPattern(op->parent, node->ref_)) { + backup_memo[op->parent] = memo_.at(op->parent); + backup_matched_nodes.push_back(op->parent); + memo_ = backup_memo; + matched_nodes_ = backup_matched_nodes; + watermark += 1; + return true; + } else { + if (VisitDFPattern(op->path, node->ref_)) { + auto new_dominated_exprs = find_dominated(op->path); + std::cout << watermark << std::endl; + ClearMap(watermark); + out &= find_parent(node->ref_, new_dominated_exprs); + } else { + out = false; + } } } } + return out; + }; + + auto dominated_exprs = find_dominated(op->child); + ClearMap(watermark); + bool matches = find_parent(expr, dominated_exprs); + if (matches) { + backup_memo[op->parent] = memo_.at(op->parent); + backup_memo[op->child] = expr; + memo_ = backup_memo; } return matches; } diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 40dd4ef27463..2fbe007616dc 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -247,12 +247,6 @@ def test_match_dominator(): # Check assert diamond.match(out) - # Pattern - is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) - reduction = is_op('add')(wildcard(), wildcard()) - diamond = dominates(is_conv2d, is_unary_elemwise, reduction) - # Expr inp = relay.var('input') weight = relay.var('weight') @@ -266,12 +260,39 @@ def test_match_dominator(): # Check assert diamond.match(out) + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + out = relu + tanh + + # Check + assert diamond.match(out) + + def test_not_match_dominator(): is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + # Expr + input1 = relay.var('input1') + weight1 = relay.var('weight1') + conv2d1 = relay.op.nn.conv2d(input1, weight1) + inp2 = relay.var('input2') + weight2 = relay.var('weight2') + conv2d2 = relay.op.nn.conv2d(inp2, weight2) + relu = relay.op.nn.relu(conv2d1) + leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(out) + # Expr inp = relay.var('input') weight = relay.var('weight') @@ -284,12 +305,6 @@ def test_not_match_dominator(): # Check assert not diamond.match(out) - # Pattern - is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) - reduction = is_op('add')(wildcard(), wildcard()) - diamond = dominates(is_conv2d, is_unary_elemwise, reduction) - # Expr inp = relay.var('input') weight = relay.var('weight') @@ -301,6 +316,16 @@ def test_not_match_dominator(): # Check assert not diamond.match(out) + # Expr + inp = relay.var('input') + relu = relay.op.nn.relu(inp) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + out = relu + tanh + + # Check + assert not diamond.match(out) + def test_rewrite(): x = relay.var('x') y = relay.var('y') From 9bad5c16d8f2320565bf422e66c7d586aedf4a88 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 14 Apr 2020 16:35:19 -0700 Subject: [PATCH 22/46] respond to PR comments --- src/relay/ir/dataflow_matcher.cc | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 251d3bfcaaa8..3939616bbfbf 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -66,6 +66,7 @@ void DFPatternMatcher::ClearMap(size_t watermark) { } matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end()); } + bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) { if (memo_.count(pattern)) { return expr.same_as(memo_[pattern]); @@ -85,6 +86,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); } + bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) { bool matches = false; if (const auto* op_node = expr.as()) { @@ -121,7 +123,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons return matches; } -Array reverse(const Array args) { +Array reverse(const Array& args) { Array new_args; for (auto it = args.rbegin(); it != args.rend(); ++it) { new_args.push_back(*it); @@ -248,6 +250,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } return false; } + bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { auto watermark = matched_nodes_.size(); auto backup_memo = memo_; @@ -310,9 +313,11 @@ bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Exp } return false; } + bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { return StructuralEqual()(op->expr, expr); } + bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { bool matches = false; if (const auto* tuple_get_item_node = expr.as()) { @@ -321,6 +326,7 @@ bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const } return matches; } + bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) { bool matches = false; if (const auto* tuple_node = expr.as()) { @@ -335,6 +341,7 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e } return matches; } + Expr InferType(const Expr& expr) { auto mod = IRModule::FromExpr(expr); mod = transform::InferType()(mod); @@ -344,10 +351,12 @@ Expr InferType(const Expr& expr) { return mod->Lookup("main").as()->body; } } + bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { auto expr_type = InferType(expr).as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); } + bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { bool matches = false; if (const auto* var_node = expr.as()) { @@ -358,6 +367,7 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& exp } return matches; } + bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) { return true; } @@ -398,6 +408,7 @@ class PatternRewriter : protected MixedModeMutator { } return out; } + DFPatternMatcher matcher_; Array callbacks_; }; From 90f99fea2646cab0ece5f7bce8a796062cbc770b Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 15 Apr 2020 13:43:17 -0700 Subject: [PATCH 23/46] more reviewer comments. Thanks masahi! --- include/tvm/relay/dataflow_functor.h | 10 +++--- include/tvm/relay/dataflow_pattern.h | 8 +++-- src/relay/ir/dataflow_matcher.cc | 2 -- tests/python/relay/test_df_pattern.py | 46 +++++++++++++-------------- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index 40a883d7f254..85b825f2ceb4 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -150,10 +150,10 @@ class DFPatternVisitor : public DFPatternFunctor { * Holds a forward-backward indexed representation of the graph and a dominator tree representation * of the graph * - * Class is Templated and the implementaiton is in the header file so we can analyis both DFPattern - * and Expr with the same infrastructure. + * This class is templated and the implementaiton is in the header file so we can analyze both + * DFPattern and Expr with the same infrastructure. * - * IndexedGraph should be instantiated thorught the CreateIndexedGraph utilities. + * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. */ template class IndexedGraph { @@ -162,7 +162,7 @@ class IndexedGraph { struct Node { /*! \brief Node Constructor * \param ref The input graph node - * \param index The index of the node in toplogoical order + * \param index The index of the node in toplogical order */ Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} @@ -183,7 +183,7 @@ class IndexedGraph { /*! \brief The nodes this node dominates */ std::vector dominator_children_; }; - /*! \brief Construct the domination create of the index graph */ + /*! \brief Construct the domination tree inside IndexedGraph */ void PostDom() { for (size_t i = topological_order_.size(); i != 0; --i) { size_t index = i - 1; diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 52ed9fe5f4ce..b151c13ed0be 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -308,7 +308,7 @@ class TypePattern : public DFPattern { class AttrPattern; /*! - * \brief Pattern for Types. + * \brief Pattern for Attributes. */ class AttrPatternNode : public DFPatternNode { public: @@ -329,7 +329,7 @@ class AttrPatternNode : public DFPatternNode { }; /*! - * \brief A pattern which matches a type in another pattern + * \brief A pattern which matches attributes in another pattern */ class AttrPattern : public DFPattern { public: @@ -338,7 +338,9 @@ class AttrPattern : public DFPattern { class DominatorPattern; /*! - * \brief Pattern for Types. + * \brief Dominated Graph Pattern + * Pattern for fuzzy subgraphs where all outputs of the parent are used finally by the child, and + * every operation between the parent and the child matches the path. */ class DominatorPatternNode : public DFPatternNode { public: diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 3939616bbfbf..9b85cc7f92de 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -100,7 +100,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons switch (op_map[op].type_code()) { case kDLInt: if (auto* val = kv.second.as()) { - std::cout << op << " " << op_map[op].operator int64_t() << std::endl; matches = val->value == op_map[op].operator int64_t(); } break; @@ -289,7 +288,6 @@ bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Exp } else { if (VisitDFPattern(op->path, node->ref_)) { auto new_dominated_exprs = find_dominated(op->path); - std::cout << watermark << std::endl; ClearMap(watermark); out &= find_parent(node->ref_, new_dominated_exprs); } else { diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 2fbe007616dc..b921537e42df 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -548,27 +548,27 @@ def test_algebraic_simplify(): assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y) if __name__ == "__main__": - #test_match_op() - #test_no_match_op() - #test_match_op_or() - #test_match_call() - #test_no_match_call() - #test_match_call_commutive() - #test_no_match_call_commutive() - #test_match_tuple() - #test_no_match_tuple() - #test_match_type() - #test_no_match_type() - #test_match_attr() - #test_no_match_attr() - #test_match_diamond() - #test_no_match_diamond() - #test_match_fake_diamond() - #test_rewrite() - #test_fuse_batchnorm() - #test_no_fuse_batchnorm() - #test_fuse_double_batchnorm() - #test_partial_fuse_double_batchnorm() - #test_fuse_batchnorm_commutation() - #test_match_dominator() + test_match_op() + test_no_match_op() + test_match_op_or() + test_match_call() + test_no_match_call() + test_match_call_commutive() + test_no_match_call_commutive() + test_match_tuple() + test_no_match_tuple() + test_match_type() + test_no_match_type() + test_match_attr() + test_no_match_attr() + test_match_diamond() + test_no_match_diamond() + test_match_fake_diamond() + test_rewrite() + test_fuse_batchnorm() + test_no_fuse_batchnorm() + test_fuse_double_batchnorm() + test_partial_fuse_double_batchnorm() + test_fuse_batchnorm_commutation() + test_match_dominator() test_not_match_dominator() From c8a56de7bbf38ee420883e68b1044bb6cb3e5ed1 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 16 Apr 2020 08:16:07 -0700 Subject: [PATCH 24/46] Refactor Dominator Matching --- src/relay/ir/dataflow_matcher.cc | 135 +++++++++++++++----------- tests/python/relay/test_df_pattern.py | 30 ++++-- 2 files changed, 100 insertions(+), 65 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 9b85cc7f92de..f4356cae892c 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -32,8 +32,12 @@ namespace relay { // Pattern Matcher + +class DominatorMatcher; + class DFPatternMatcher : public DFPatternFunctor { public: + explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} bool Match(const DFPattern& pattern, const Expr& expr); protected: @@ -52,7 +56,9 @@ class DFPatternMatcher : public DFPatternFunctor memo_; std::vector matched_nodes_; - }; + IndexedGraph expr_graph_; + friend DominatorMatcher; +}; bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { memo_.clear(); @@ -250,66 +256,78 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex return false; } -bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { - auto watermark = matched_nodes_.size(); - auto backup_memo = memo_; - auto backup_matched_nodes = matched_nodes_; - - if (VisitDFPattern(op->child, expr)) { - auto child_graph = CreateIndexedGraph(GetRef(op)); - auto expr_graph = CreateIndexedGraph(expr); - auto find_dominated = [&child_graph, this](const DFPattern& node) { - std::unordered_set dominated_exprs; - auto indexed_node = child_graph.node_map_[node]; - for (auto dominated : indexed_node->dominator_children_) { - if (dominated->ref_.as() || dominated->ref_.as()) { - continue; - } - dominated_exprs.insert(memo_[dominated->ref_]); +// Friend class to do recursive dominator matching +class DominatorMatcher { + public: + DominatorMatcher(DFPatternMatcher* matcher, const DominatorPatternNode* op, const Expr& expr) + : matcher_(matcher), op_(op), expr_(expr) { + watermark_ = matcher_->matched_nodes_.size(); + pattern_graph_ = CreateIndexedGraph(GetRef(op)); + } + bool Match() { + if (matcher_->VisitDFPattern(op_->child, expr_)) { + auto dominated_exprs = FindDominated(op_->child); + matcher_->ClearMap(watermark_); + + bool matches = FindParent(expr_, dominated_exprs); + if (matches) { + matcher_->ClearMap(watermark_); + matcher_->memo_[op_->child] = expr_; + matcher_->matched_nodes_.push_back(op_->child); } - return dominated_exprs; - }; - std::function&)> - find_parent; - find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated, - &expr_graph, &find_parent]( - const Expr& expr, - const std::unordered_set& dominated_exprs) { - bool out = true; - for (auto node : expr_graph.node_map_[expr]->dominator_children_) { - if (out && dominated_exprs.count(node->ref_) == 0) { - if (VisitDFPattern(op->parent, node->ref_)) { - backup_memo[op->parent] = memo_.at(op->parent); - backup_matched_nodes.push_back(op->parent); - memo_ = backup_memo; - matched_nodes_ = backup_matched_nodes; - watermark += 1; - return true; + return matches; + } + return false; + } + + protected: + DFPatternMatcher* matcher_; + const DominatorPatternNode* op_; + IndexedGraph pattern_graph_; + Expr expr_; + size_t watermark_; + + std::unordered_set FindDominated(const DFPattern& node) { + std::unordered_set dominated_exprs; + auto indexed_node = pattern_graph_.node_map_[node]; + for (auto dominated : indexed_node->dominator_children_) { + if (dominated->ref_.as()) { + continue; + } + if (matcher_->memo_.count(dominated->ref_)) { + dominated_exprs.insert(matcher_->memo_[dominated->ref_]); + } + } + return dominated_exprs; + } + bool FindParent(const Expr& expr, + const std::unordered_set& dominated_exprs) { + bool out = true; + for (auto node : matcher_->expr_graph_.node_map_[expr]->dominator_children_) { + if (out && dominated_exprs.count(node->ref_) == 0 && node->ref_.as() == nullptr) { + if (matcher_->VisitDFPattern(op_->parent, node->ref_)) { + matcher_->ClearMap(watermark_); + matcher_->memo_[op_->parent] = node->ref_; + matcher_->matched_nodes_.push_back(op_->parent); + watermark_ += 1; + return true; + } else { + if (matcher_->VisitDFPattern(op_->path, node->ref_)) { + auto new_dominated_exprs = FindDominated(op_->path); + matcher_->ClearMap(watermark_); + out &= FindParent(node->ref_, new_dominated_exprs); } else { - if (VisitDFPattern(op->path, node->ref_)) { - auto new_dominated_exprs = find_dominated(op->path); - ClearMap(watermark); - out &= find_parent(node->ref_, new_dominated_exprs); - } else { - out = false; - } + out = false; } } } - return out; - }; - - auto dominated_exprs = find_dominated(op->child); - ClearMap(watermark); - bool matches = find_parent(expr, dominated_exprs); - if (matches) { - backup_memo[op->parent] = memo_.at(op->parent); - backup_memo[op->child] = expr; - memo_ = backup_memo; } - return matches; + return out; } - return false; +}; + +bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { + return DominatorMatcher(this, op, expr).Match(); } bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { @@ -371,7 +389,7 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr } TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) { - return DFPatternMatcher().Match(pattern, expr); + return DFPatternMatcher(expr).Match(pattern, expr); }); // Rewrite @@ -390,7 +408,8 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") class PatternRewriter : protected MixedModeMutator { public: - explicit PatternRewriter(const Array& callbacks) : callbacks_(callbacks) {} + explicit PatternRewriter(const Array& callbacks, const Expr& root_expr) + : callbacks_(callbacks), matcher_(DFPatternMatcher(root_expr)) {} Expr Rewrite(const Expr& pre) { return this->VisitExpr(pre); } protected: @@ -407,12 +426,12 @@ class PatternRewriter : protected MixedModeMutator { return out; } - DFPatternMatcher matcher_; Array callbacks_; + DFPatternMatcher matcher_; }; Expr RewritePatterns(Array callbacks, Expr expr) { - return PatternRewriter(callbacks).Rewrite(expr); + return PatternRewriter(callbacks, expr).Rewrite(expr); } TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns); diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index b921537e42df..c70305946e8b 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -235,7 +235,7 @@ def test_match_dominator(): reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction) - # Expr + # Classic Diamond inp = relay.var('input') weight = relay.var('weight') conv2d = relay.op.nn.conv2d(inp, weight) @@ -247,7 +247,7 @@ def test_match_dominator(): # Check assert diamond.match(out) - # Expr + # Deeper Branch inp = relay.var('input') weight = relay.var('weight') conv2d = relay.op.nn.conv2d(inp, weight) @@ -260,7 +260,7 @@ def test_match_dominator(): # Check assert diamond.match(out) - # Expr + # Single Branch inp = relay.var('input') weight = relay.var('weight') conv2d = relay.op.nn.conv2d(inp, weight) @@ -272,6 +272,22 @@ def test_match_dominator(): # Check assert diamond.match(out) + # Fuzzy path/nested Diamond + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + tanh = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = tanh + leaky_relu + + assert diamond.match(out) def test_not_match_dominator(): is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) @@ -279,7 +295,7 @@ def test_not_match_dominator(): reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction) - # Expr + # Fake Diamond input1 = relay.var('input1') weight1 = relay.var('weight1') conv2d1 = relay.op.nn.conv2d(input1, weight1) @@ -293,7 +309,7 @@ def test_not_match_dominator(): # Check assert not diamond.match(out) - # Expr + # Add op that doesn't match K_ELEMWISE inp = relay.var('input') weight = relay.var('weight') conv2d = relay.op.nn.conv2d(inp, weight) @@ -305,7 +321,7 @@ def test_not_match_dominator(): # Check assert not diamond.match(out) - # Expr + # Relu on the input instead of the conv inp = relay.var('input') weight = relay.var('weight') conv2d = relay.op.nn.conv2d(inp, weight) @@ -316,7 +332,7 @@ def test_not_match_dominator(): # Check assert not diamond.match(out) - # Expr + # No conv inp = relay.var('input') relu = relay.op.nn.relu(inp) relu = relay.op.nn.relu(relu) From e2993854c4c7505831fec5e56a15596858136917 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 16 Apr 2020 14:49:18 -0700 Subject: [PATCH 25/46] Extend Memoization Allows disabling memoization and storing multiple matches when memoization is disabled. --- include/tvm/relay/dataflow_functor.h | 2 + src/relay/ir/dataflow_functor.cc | 4 +- src/relay/ir/dataflow_matcher.cc | 69 +++++++++++++++------------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index 85b825f2ceb4..a0edeb224f62 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -173,6 +173,8 @@ class IndexedGraph { /*! \brief A boolean to determine if this node is external to the graph */ bool is_external_ = false; + /*! \brief The forward inputs of the node */ + std::vector inputs_; /*! \brief The forward outputs/users of the node */ std::vector outputs_; diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_functor.cc index ffbb4b342756..c367b228d4b0 100644 --- a/src/relay/ir/dataflow_functor.cc +++ b/src/relay/ir/dataflow_functor.cc @@ -118,11 +118,13 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { return std::move(graph_); } - /*! Default visitation pushes the parent to the child's ouputs */ + /*! Default visitation pushes the parent to the child's ouputs and the child to the parent's + * inputs*/ void VisitExpr(const Expr& expr, NodePtr parent) override { auto current = graph_.node_map_[expr]; if (parent) { current->outputs_.push_back(parent.get()); + parent->inputs_.push_back(current.get()); } } diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index f4356cae892c..20f55fa245fb 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -39,7 +39,7 @@ class DFPatternMatcher : public DFPatternFunctor> GetMemo() { return Map>(memo_); } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; @@ -54,10 +54,11 @@ class DFPatternMatcher : public DFPatternFunctor memo_; + std::unordered_map, ObjectHash, ObjectEqual> memo_; std::vector matched_nodes_; IndexedGraph expr_graph_; friend DominatorMatcher; + bool memoize_ = true; }; bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { @@ -74,13 +75,14 @@ void DFPatternMatcher::ClearMap(size_t watermark) { } bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) { - if (memo_.count(pattern)) { - return expr.same_as(memo_[pattern]); + if (memoize_ && memo_.count(pattern)) { + CHECK_EQ(memo_[pattern].size(), 1); + return expr.same_as(memo_[pattern][0]); } else { auto watermark = matched_nodes_.size(); auto out = DFPatternFunctor::VisitDFPattern(pattern, expr); if (out) { - memo_[pattern] = expr; + memo_[pattern].push_back(expr); matched_nodes_.push_back(pattern); } else { ClearMap(watermark); @@ -261,20 +263,14 @@ class DominatorMatcher { public: DominatorMatcher(DFPatternMatcher* matcher, const DominatorPatternNode* op, const Expr& expr) : matcher_(matcher), op_(op), expr_(expr) { - watermark_ = matcher_->matched_nodes_.size(); pattern_graph_ = CreateIndexedGraph(GetRef(op)); } bool Match() { if (matcher_->VisitDFPattern(op_->child, expr_)) { + matcher_->memoize_ = false; auto dominated_exprs = FindDominated(op_->child); - matcher_->ClearMap(watermark_); - bool matches = FindParent(expr_, dominated_exprs); - if (matches) { - matcher_->ClearMap(watermark_); - matcher_->memo_[op_->child] = expr_; - matcher_->matched_nodes_.push_back(op_->child); - } + matcher_->memoize_ = true; return matches; } return false; @@ -285,7 +281,6 @@ class DominatorMatcher { const DominatorPatternNode* op_; IndexedGraph pattern_graph_; Expr expr_; - size_t watermark_; std::unordered_set FindDominated(const DFPattern& node) { std::unordered_set dominated_exprs; @@ -295,7 +290,7 @@ class DominatorMatcher { continue; } if (matcher_->memo_.count(dominated->ref_)) { - dominated_exprs.insert(matcher_->memo_[dominated->ref_]); + dominated_exprs.insert(matcher_->memo_[dominated->ref_].back()); } } return dominated_exprs; @@ -305,16 +300,13 @@ class DominatorMatcher { bool out = true; for (auto node : matcher_->expr_graph_.node_map_[expr]->dominator_children_) { if (out && dominated_exprs.count(node->ref_) == 0 && node->ref_.as() == nullptr) { + matcher_->memoize_ = true; if (matcher_->VisitDFPattern(op_->parent, node->ref_)) { - matcher_->ClearMap(watermark_); - matcher_->memo_[op_->parent] = node->ref_; - matcher_->matched_nodes_.push_back(op_->parent); - watermark_ += 1; return true; } else { + matcher_->memoize_ = false; if (matcher_->VisitDFPattern(op_->path, node->ref_)) { auto new_dominated_exprs = FindDominated(op_->path); - matcher_->ClearMap(watermark_); out &= FindParent(node->ref_, new_dominated_exprs); } else { out = false; @@ -408,30 +400,41 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") class PatternRewriter : protected MixedModeMutator { public: - explicit PatternRewriter(const Array& callbacks, const Expr& root_expr) - : callbacks_(callbacks), matcher_(DFPatternMatcher(root_expr)) {} - Expr Rewrite(const Expr& pre) { return this->VisitExpr(pre); } + PatternRewriter() {} + Expr Rewrite(const Array& callbacks, const Expr& pre) { + auto post = pre; + auto last = post; + // rewrite the graph until it stops changing to make sure all rewrites are complete + do { + last = post; + for (auto callback : callbacks) { + callback_ = &callback; + auto matcher = DFPatternMatcher(post); + matcher_ = &matcher; + memo_.clear(); + post = this->VisitExpr(post); + } + } while (last != post); + return post; + } protected: Expr DispatchVisitExpr(const Expr& pre) override { auto post = MixedModeMutator::DispatchVisitExpr(pre); - Expr out = post; - for (auto& callback : callbacks_) { - if (auto* callback_node = callback.as()) { - if (matcher_.Match(callback_node->pattern_, out)) { - out = callback_node->function_(pre, out); - } + if (auto* callback_node = callback_->as()) { + if (matcher_->Match(callback_node->pattern_, post)) { + return callback_node->function_(pre, post); } } - return out; + return post; } - Array callbacks_; - DFPatternMatcher matcher_; + DFPatternMatcher* matcher_ = nullptr; + DFPatternCallback* callback_ = nullptr; }; Expr RewritePatterns(Array callbacks, Expr expr) { - return PatternRewriter(callbacks, expr).Rewrite(expr); + return PatternRewriter().Rewrite(callbacks, expr); } TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns); From 00f9db3b1c7324ee3361aa471406eea93e1082a6 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 16 Apr 2020 15:12:57 -0700 Subject: [PATCH 26/46] Fold DominatorMatcher back into DFPatternMatcher --- src/relay/ir/dataflow_matcher.cc | 99 +++++++++++++++----------------- 1 file changed, 45 insertions(+), 54 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 20f55fa245fb..3087f3e71896 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -54,10 +54,15 @@ class DFPatternMatcher : public DFPatternFunctor FindDominated(const DFPattern& node); + bool FindParent(const Expr& expr, + const std::unordered_set& dominated_exprs, + const DominatorPatternNode* op); + std::unordered_map, ObjectHash, ObjectEqual> memo_; std::vector matched_nodes_; IndexedGraph expr_graph_; - friend DominatorMatcher; + IndexedGraph pattern_graph_; bool memoize_ = true; }; @@ -258,68 +263,54 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex return false; } -// Friend class to do recursive dominator matching -class DominatorMatcher { - public: - DominatorMatcher(DFPatternMatcher* matcher, const DominatorPatternNode* op, const Expr& expr) - : matcher_(matcher), op_(op), expr_(expr) { - pattern_graph_ = CreateIndexedGraph(GetRef(op)); - } - bool Match() { - if (matcher_->VisitDFPattern(op_->child, expr_)) { - matcher_->memoize_ = false; - auto dominated_exprs = FindDominated(op_->child); - bool matches = FindParent(expr_, dominated_exprs); - matcher_->memoize_ = true; - return matches; +std::unordered_set DFPatternMatcher::FindDominated( + const DFPattern& node) { + std::unordered_set dominated_exprs; + auto indexed_node = pattern_graph_.node_map_[node]; + for (auto dominated : indexed_node->dominator_children_) { + if (dominated->ref_.as()) { + continue; } - return false; - } - - protected: - DFPatternMatcher* matcher_; - const DominatorPatternNode* op_; - IndexedGraph pattern_graph_; - Expr expr_; - - std::unordered_set FindDominated(const DFPattern& node) { - std::unordered_set dominated_exprs; - auto indexed_node = pattern_graph_.node_map_[node]; - for (auto dominated : indexed_node->dominator_children_) { - if (dominated->ref_.as()) { - continue; - } - if (matcher_->memo_.count(dominated->ref_)) { - dominated_exprs.insert(matcher_->memo_[dominated->ref_].back()); - } + if (memo_.count(dominated->ref_)) { + Array matched = memo_[dominated->ref_]; + dominated_exprs.insert(matched[matched.size() - 1]); } - return dominated_exprs; } - bool FindParent(const Expr& expr, - const std::unordered_set& dominated_exprs) { - bool out = true; - for (auto node : matcher_->expr_graph_.node_map_[expr]->dominator_children_) { - if (out && dominated_exprs.count(node->ref_) == 0 && node->ref_.as() == nullptr) { - matcher_->memoize_ = true; - if (matcher_->VisitDFPattern(op_->parent, node->ref_)) { - return true; + return dominated_exprs; +} + +bool DFPatternMatcher::FindParent( + const Expr& expr, const std::unordered_set& dominated_exprs, + const DominatorPatternNode* op) { + bool out = true; + for (auto node : expr_graph_.node_map_[expr]->dominator_children_) { + if (out && dominated_exprs.count(node->ref_) == 0 && node->ref_.as() == nullptr) { + memoize_ = true; + if (VisitDFPattern(op->parent, node->ref_)) { + return true; + } else { + memoize_ = false; + if (VisitDFPattern(op->path, node->ref_)) { + auto new_dominated_exprs = FindDominated(op->path); + out &= FindParent(node->ref_, new_dominated_exprs, op); } else { - matcher_->memoize_ = false; - if (matcher_->VisitDFPattern(op_->path, node->ref_)) { - auto new_dominated_exprs = FindDominated(op_->path); - out &= FindParent(node->ref_, new_dominated_exprs); - } else { - out = false; - } + out = false; } } } - return out; } -}; - + return out; +} bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { - return DominatorMatcher(this, op, expr).Match(); + pattern_graph_ = CreateIndexedGraph(GetRef(op)); + if (VisitDFPattern(op->child, expr)) { + memoize_ = false; + auto dominated_exprs = FindDominated(op->child); + bool matches = FindParent(expr, dominated_exprs, op); + memoize_ = true; + return matches; + } + return false; } bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { From 1bd45055ac30e17307568e9d09ed7385fa280e2e Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 16 Apr 2020 16:07:24 -0700 Subject: [PATCH 27/46] move InferType Function --- include/tvm/relay/transform.h | 9 +++++++++ src/relay/ir/dataflow_matcher.cc | 10 ---------- src/relay/transforms/type_infer.cc | 5 +++++ 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 9a8ca8421997..62799fbaf299 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -375,6 +375,15 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); */ TVM_DLL Function InferType(const Function& f, const IRModule& mod, const GlobalVar& var); +/*! + * \brief Infer the type of an expression base on it's inputs. + * + * \param expr the Expr. + * + * \return A type checked Expr with its checked_type field populated. + */ +TVM_DLL Expr InferType(const Expr& expr); + /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. This * function is used as a helper function to rewrtie an expression in a pass. diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 3087f3e71896..af4886868a3d 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -341,16 +341,6 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e return matches; } -Expr InferType(const Expr& expr) { - auto mod = IRModule::FromExpr(expr); - mod = transform::InferType()(mod); - if (expr.as()) { - return mod->Lookup("main"); - } else { - return mod->Lookup("main").as()->body; - } -} - bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { auto expr_type = InferType(expr).as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 078248483587..a7334eb09525 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -742,6 +742,11 @@ Function InferType(const Function& func, const IRModule& mod, const GlobalVar& v return Downcast(func_ret); } +Expr InferType(const Expr& expr) { + auto mod = IRModule::FromExpr(expr); + return InferType(expr, mod); +} + namespace transform { Pass InferType() { From ea4b76266fa57a13d3b1b24d46cf60403c46bce5 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 20 Apr 2020 09:28:19 -0700 Subject: [PATCH 28/46] respond to review comments --- src/relay/ir/dataflow_matcher.cc | 60 +++++++++++++++----------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index af4886868a3d..e12473676e35 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -40,6 +40,7 @@ class DFPatternMatcher : public DFPatternFunctor> GetMemo() { return Map>(memo_); } + protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; @@ -214,27 +215,24 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // associate divide/multiply if (is_pattern_op(op, "divide")) { if (const auto* arg_node = op->args[0].as()) { - if (is_pattern_op(arg_node, "multiply")) { - if (is_expr_op(expr, "multiply")) { - if (is_expr_op(call_node->args[0], "divide") || - is_expr_op(call_node->args[1], "divide")) { - bool out = false; - for (size_t arg_id = 0; arg_id < 2; ++arg_id) { - auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]}, - op->attrs, op->type_args); - auto mul = - CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}, - arg_node->attrs, arg_node->type_args); - out = VisitDFPattern(mul, expr); - if (out) { - return out; - } else { - ClearMap(watermark); - } - } - return out; + if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") && + (is_expr_op(call_node->args[0], "divide") || + is_expr_op(call_node->args[1], "divide"))) { + bool out = false; + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]}, + op->attrs, op->type_args); + auto mul = + CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}, + arg_node->attrs, arg_node->type_args); + out = VisitDFPattern(mul, expr); + if (out) { + return true; + } else { + ClearMap(watermark); } } + return out; } } } @@ -242,18 +240,15 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // associate multiply/divide for (size_t arg_id = 0; arg_id < 2; ++arg_id) { if (auto* arg_node = op->args[arg_id].as()) { - if (is_pattern_op(arg_node, "divide")) { - if (is_expr_op(expr, "divide")) { - if (is_expr_op(call_node->args[0], "multiply") || - is_expr_op(call_node->args[1], "multiply")) { - auto mul = - CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, - op->attrs, op->type_args); - auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]}, - arg_node->attrs, arg_node->type_args); - return VisitDFPattern(div, expr); - } - } + if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") && + (is_expr_op(call_node->args[0], "multiply") || + is_expr_op(call_node->args[1], "multiply"))) { + auto mul = + CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, + op->attrs, op->type_args); + auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]}, + arg_node->attrs, arg_node->type_args); + return VisitDFPattern(div, expr); } } } @@ -294,13 +289,14 @@ bool DFPatternMatcher::FindParent( auto new_dominated_exprs = FindDominated(op->path); out &= FindParent(node->ref_, new_dominated_exprs, op); } else { - out = false; + return false; } } } } return out; } + bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { pattern_graph_ = CreateIndexedGraph(GetRef(op)); if (VisitDFPattern(op->child, expr)) { From 60ff4bef04d918dba1d62a1fe4bf673da3920178 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 20 Apr 2020 11:30:46 -0700 Subject: [PATCH 29/46] Clean up rewriter API --- python/tvm/relay/df_pattern/__init__.py | 43 +++-- src/relay/ir/dataflow_matcher.cc | 2 +- tests/python/relay/test_df_pattern.py | 208 +++++++++++------------- 3 files changed, 128 insertions(+), 125 deletions(-) diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index edae2d6c32bc..2280592eeb05 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -400,25 +400,16 @@ def __init__(self, parent: DFPattern, path: DFPattern, child: DFPattern): ffi.DominatorPattern, parent, path, child) -class DFPatternCallback(Object): +class DFPatternCallback: """A Callback for Pattern Rewriting When rewrite is called on this DFPatternCallback, the backend will find matches for the pattern, call the callback function, and replace the matched expression with whatever the callback returns. - Parameters - ---------- - pattern: tvm.relay.df_pattern.DFPattern - The Pattern to match - callback: PackedFunc - The callback function. + Users are expect to inherit from this class and provide a "self.pattern" to match """ - def __init__(self, pattern: DFPattern, callback): - self.__init_handle_by_constructor__( - ffi.DFPatternCallback, pattern, callback) - def rewrite(self, expr: Expr) -> Expr: """ Rewrite expression with this callback @@ -430,6 +421,27 @@ def rewrite(self, expr: Expr) -> Expr: """ return rewrite(self, expr) + def callback(self, pre, post, node_map): + """ + Callback function to use when we found a match to the pattern + + Parameters + ---------- + pre : tvm.relay.Expr + The matching expression from the original graph. + post : tvm.relay.Expr + The matching expression with rewritten inputs + node_map : Map(DFPattern, List(Expr)) + The map between patterns and matched expressions + """ + raise "Unimplemented" + +class _DFPatternCallback(Object): + """C++ implemenation""" + def __init__(self, pattern, callback): + self.__init_handle_by_constructor__( + ffi.DFPatternCallback, pattern, callback) + def rewrite(callbacks, expr: Expr) -> Expr: """ @@ -443,5 +455,10 @@ def rewrite(callbacks, expr: Expr) -> Expr: The expression to rewrite. """ if isinstance(callbacks, DFPatternCallback): - callbacks = [callbacks] - return ffi.rewrite(callbacks, expr) + tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)] + else: + tmp = [] + for callback in callbacks: + tmp.append(_DFPatternCallback(callback.pattern, callback.callback)) + + return ffi.rewrite(tmp, expr) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index e12473676e35..7a2a61f71a31 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -400,7 +400,7 @@ class PatternRewriter : protected MixedModeMutator { auto post = MixedModeMutator::DispatchVisitExpr(pre); if (auto* callback_node = callback_->as()) { if (matcher_->Match(callback_node->pattern_, post)) { - return callback_node->function_(pre, post); + return callback_node->function_(pre, post, matcher_->GetMemo()); } } return post; diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index c70305946e8b..cf25a19e67b8 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -347,9 +347,12 @@ def test_rewrite(): y = relay.var('y') add_pattern = is_op('add')(wildcard(), wildcard()) sub_pattern = is_op('subtract')(wildcard(), wildcard()) - def add_to_sub(pre, post): - return post.args[0] - post.args[1] - out = rewrite([DFPatternCallback(add_pattern, add_to_sub)], x + y) + class TestRewrite(DFPatternCallback): + def __init__(self): + self.pattern = add_pattern + def callback(self, pre, post, node_map): + return post.args[0] - post.args[1] + out = rewrite(TestRewrite(), x + y) assert sub_pattern.match(out) def test_not_fuse_multi_diamond(): @@ -370,35 +373,25 @@ def test_not_fuse_multi_diamond(): # Check assert not diamond.match(out) -def fuse_batchnorm(pre, post): - def left_right_call(post): - if isinstance(post.args[0], relay.Call): - return (post.args[1], post.args[0]) - else: - return (post.args[0], post.args[1]) - - beta, post = left_right_call(post) - assert isinstance(post, relay.Call) - - if post.op == relay.op.get("divide"): - numerator = post.args[0] - denominator = post.args[1] - gamma, numerator = left_right_call(numerator) - elif post.op == relay.op.get("multiply"): - gamma, quotient = left_right_call(post) - numerator = quotient.args[0] - denominator = quotient.args[1] - else: - raise "Found unexcepted op" - - x = numerator.args[0] - mean = numerator.args[1] - - var = denominator.args[0].args[0] - eps = denominator.args[0].args[1] - - out = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item()) - return out[0] +class BatchnormCallback(DFPatternCallback): + def __init__(self): + self.x = wildcard() + self.var = wildcard() + self.mean = wildcard() + self.beta = wildcard() + self.gamma = wildcard() + self.eps = wildcard() + + self.pattern = self.gamma * (self.x - self.mean)/is_op("sqrt")(self.var + self.eps) + self.beta + + def callback(self, pre, post, node_map): + x = node_map[self.x][0] + var = node_map[self.var][0] + mean = node_map[self.mean][0] + beta = node_map[self.beta][0] + gamma = node_map[self.gamma][0] + eps = node_map[self.eps][0] + return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item())[0] def test_fuse_batchnorm(): x = relay.var('x') @@ -407,10 +400,9 @@ def test_fuse_batchnorm(): beta = relay.var('beta') gamma = relay.var('gamma') - BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta - out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) + out = rewrite(BatchnormCallback(), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) def test_no_fuse_batchnorm(): @@ -420,10 +412,9 @@ def test_no_fuse_batchnorm(): beta = relay.var('beta') gamma = relay.var('gamma') - BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() fake_BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta - out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), fake_BN) + out = rewrite(BatchnormCallback(), fake_BN) assert tvm.ir.structural_equal(out, fake_BN) def test_fuse_double_batchnorm(): @@ -433,11 +424,10 @@ def test_fuse_double_batchnorm(): beta = relay.var('beta') gamma = relay.var('gamma') - BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta - out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN2) + out = rewrite(BatchnormCallback(), BN2) bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0] bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon = 1e-5)[0] @@ -451,11 +441,10 @@ def test_partial_fuse_double_batchnorm(): beta = relay.var('beta') gamma = relay.var('gamma') - BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta - out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN2) + out = rewrite(BatchnormCallback(), BN2) bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon = 1e-5)[0] @@ -468,74 +457,70 @@ def test_fuse_batchnorm_commutation(): beta = relay.var('beta') gamma = relay.var('gamma') - BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard() #commute add BN = beta + gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) + out = rewrite(BatchnormCallback(), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) # associate divide/multiply BN = (gamma * (x - mean)) /relay.op.sqrt(var + relay.const(1e-5)) + beta - out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) + out = rewrite(BatchnormCallback(), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) # associate multiply/divide BN = gamma * ((x - mean)/relay.op.sqrt(var + relay.const(1e-5))) + beta - out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN) + out = rewrite(BatchnormCallback(), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) def algebraic_simplify(expr): - pattern_callbacks = [] - - def elwise_zero_callback(pre, post): - if (tvm.ir.structural_equal(post.args[0], relay.const(0)) | - tvm.ir.structural_equal(post.args[0], relay.const(0.0))): - return post.args[1] - else: - return post.args[0] - - def elwise_one_callback(pre, post): - if (tvm.ir.structural_equal(post.args[0], relay.const(1)) | - tvm.ir.structural_equal(post.args[0], relay.const(1.0))): - return post.args[1] - else: - return post.args[0] - - def return_zero_callback(pre, post): - if (tvm.ir.structural_equal(post.args[0], relay.const(0)) | - tvm.ir.structural_equal(post.args[0], relay.const(0.0))): - return post.args[0] - else: - return post.args[1] - zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0))) one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0))) - add_pattern = wildcard() + zero - pattern_callbacks.append(DFPatternCallback(add_pattern, elwise_zero_callback)) - - sub_pattern = wildcard() - zero - pattern_callbacks.append(DFPatternCallback(sub_pattern, elwise_zero_callback)) - - mul_pattern = wildcard() * one - pattern_callbacks.append(DFPatternCallback(mul_pattern, elwise_one_callback)) + class ElwiseNullCallback(DFPatternCallback): + def callback(self, pre, post, node_map): + return node_map[self.x][0] + + class AddCallback(ElwiseNullCallback): + def __init__(self): + self.x = wildcard() + self.pattern = self.x + zero - mul_zero_pattern = wildcard() * zero - pattern_callbacks.append(DFPatternCallback(mul_zero_pattern, return_zero_callback)) - - div_pattern = wildcard() / one - pattern_callbacks.append(DFPatternCallback(div_pattern, elwise_one_callback)) - - zero_div_pattern = zero / wildcard() - pattern_callbacks.append(DFPatternCallback(zero_div_pattern, return_zero_callback)) - - return rewrite(pattern_callbacks, expr); + class SubCallback(ElwiseNullCallback): + def __init__(self): + self.x = wildcard() + self.pattern = self.x - zero + + class MulCallback(ElwiseNullCallback): + def __init__(self): + self.x = wildcard() + self.pattern = self.x * one + + class DivCallback(ElwiseNullCallback): + def __init__(self): + self.x = wildcard() + self.pattern = self.x / one + + class MulZeroCallback(ElwiseNullCallback): + def __init__(self): + self.x = zero + self.pattern = self.x * wildcard() + + class ZeroDivCallback(ElwiseNullCallback): + def __init__(self): + self.x = zero + self.pattern = self.x / wildcard() + + return rewrite([AddCallback(), + SubCallback(), + MulCallback(), + DivCallback(), + MulZeroCallback(), + ZeroDivCallback() + ], expr); def test_algebraic_simplify(): x = relay.Var('x') y = relay.Var('y') - print(x + relay.const(0)) - one = relay.const(1) zero = relay.const(0) onef = relay.const(1.0) @@ -564,27 +549,28 @@ def test_algebraic_simplify(): assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y) if __name__ == "__main__": - test_match_op() - test_no_match_op() - test_match_op_or() - test_match_call() - test_no_match_call() - test_match_call_commutive() - test_no_match_call_commutive() - test_match_tuple() - test_no_match_tuple() - test_match_type() - test_no_match_type() - test_match_attr() - test_no_match_attr() - test_match_diamond() - test_no_match_diamond() - test_match_fake_diamond() - test_rewrite() - test_fuse_batchnorm() - test_no_fuse_batchnorm() - test_fuse_double_batchnorm() - test_partial_fuse_double_batchnorm() - test_fuse_batchnorm_commutation() - test_match_dominator() - test_not_match_dominator() + #test_match_op() + #test_no_match_op() + #test_match_op_or() + #test_match_call() + #test_no_match_call() + #test_match_call_commutive() + #test_no_match_call_commutive() + #test_match_tuple() + #test_no_match_tuple() + #test_match_type() + #test_no_match_type() + #test_match_attr() + #test_no_match_attr() + #test_match_diamond() + #test_no_match_diamond() + #test_match_fake_diamond() + #test_rewrite() + #test_fuse_batchnorm() + #test_no_fuse_batchnorm() + #test_fuse_double_batchnorm() + #test_partial_fuse_double_batchnorm() + #test_fuse_batchnorm_commutation() + #test_match_dominator() + #test_not_match_dominator() + test_algebraic_simplify() From 9b7fd472756e5ca3b27198f4a8bd855e13972a21 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 20 Apr 2020 11:32:37 -0700 Subject: [PATCH 30/46] fix review comment --- src/relay/ir/dataflow_matcher.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 7a2a61f71a31..aef6edc3db17 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -162,7 +162,6 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } return false; }; - auto is_expr_op = [](const Expr& expr, std::string op_type) { if (const auto* call_node = expr.as()) { if (const auto* op_node = call_node->op.as()) { From 89a89ba60032363b2a912a782f78b30cd9893716 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 20 Apr 2020 15:26:42 -0700 Subject: [PATCH 31/46] initial partitioner --- python/tvm/relay/df_pattern/__init__.py | 24 +++ src/relay/ir/dataflow_functor.cc | 1 + src/relay/ir/dataflow_matcher.cc | 201 ++++++++++++++++++++---- tests/python/relay/test_df_pattern.py | 137 +++++++++++++--- 4 files changed, 308 insertions(+), 55 deletions(-) diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 2280592eeb05..41cb9c7451e7 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -96,6 +96,17 @@ def match(self, expr: Expr) -> bool: """ return match(self, expr) + def partition(self, expr: Expr) -> bool: + """ + Parition the expression into functions defined by this pattern + + Parameters + ---------- + expr : tvm.relay.Expr + The expression to match. + """ + return partition(self, expr) + def dominates(self, parent, path=None): """ Create a dominator for this partern @@ -462,3 +473,16 @@ def rewrite(callbacks, expr: Expr) -> Expr: tmp.append(_DFPatternCallback(callback.pattern, callback.callback)) return ffi.rewrite(tmp, expr) + +def partition(pattern: DFPattern, expr: Expr) -> Expr: + """ + Rewrite expression with the given callbacks + + Parameters + ---------- + partion: tvm.relay.df_pattern.DFPattern + The pattern to separate into functions + expr : tvm.relay.Expr + The expression to rewrite. + """ + return ffi.partition(pattern, expr) diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_functor.cc index c367b228d4b0..15d41475d0da 100644 --- a/src/relay/ir/dataflow_functor.cc +++ b/src/relay/ir/dataflow_functor.cc @@ -268,6 +268,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { auto current = graph_.node_map_[pattern]; if (parent) { current->outputs_.push_back(parent.get()); + parent->inputs_.push_back(current.get()); } } diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index aef6edc3db17..c9e24e4f8ae2 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -55,10 +56,8 @@ class DFPatternMatcher : public DFPatternFunctor FindDominated(const DFPattern& node); - bool FindParent(const Expr& expr, - const std::unordered_set& dominated_exprs, - const DominatorPatternNode* op); + bool MatchesPath(const DominatorPatternNode* op, const Expr& expr); + bool DominatesParent(const DominatorPatternNode* op, const Expr& expr); std::unordered_map, ObjectHash, ObjectEqual> memo_; std::vector matched_nodes_; @@ -257,36 +256,18 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex return false; } -std::unordered_set DFPatternMatcher::FindDominated( - const DFPattern& node) { - std::unordered_set dominated_exprs; - auto indexed_node = pattern_graph_.node_map_[node]; - for (auto dominated : indexed_node->dominator_children_) { - if (dominated->ref_.as()) { - continue; - } - if (memo_.count(dominated->ref_)) { - Array matched = memo_[dominated->ref_]; - dominated_exprs.insert(matched[matched.size() - 1]); - } - } - return dominated_exprs; -} - -bool DFPatternMatcher::FindParent( - const Expr& expr, const std::unordered_set& dominated_exprs, - const DominatorPatternNode* op) { +bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { bool out = true; - for (auto node : expr_graph_.node_map_[expr]->dominator_children_) { - if (out && dominated_exprs.count(node->ref_) == 0 && node->ref_.as() == nullptr) { + auto call_node = expr.as(); + for (auto node : expr_graph_.node_map_[expr]->inputs_) { + if (!(call_node && node->ref_ == call_node->op)) { memoize_ = true; if (VisitDFPattern(op->parent, node->ref_)) { return true; } else { memoize_ = false; if (VisitDFPattern(op->path, node->ref_)) { - auto new_dominated_exprs = FindDominated(op->path); - out &= FindParent(node->ref_, new_dominated_exprs, op); + out &= MatchesPath(op, node->ref_); } else { return false; } @@ -296,14 +277,35 @@ bool DFPatternMatcher::FindParent( return out; } +bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) { + std::stack stack; + std::unordered_set visited; + stack.push(expr); + while (!stack.empty()) { + Expr current = stack.top(); + stack.pop(); + for (auto node : expr_graph_.node_map_[current]->dominator_children_) { + if (visited.count(node->ref_) == 0) { + if (VisitDFPattern(op->parent, node->ref_)) { + return true; + } else { + stack.push(node->ref_); + } + visited.insert(node->ref_); + } + } + } + return false; +} + bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { pattern_graph_ = CreateIndexedGraph(GetRef(op)); if (VisitDFPattern(op->child, expr)) { - memoize_ = false; - auto dominated_exprs = FindDominated(op->child); - bool matches = FindParent(expr, dominated_exprs, op); + bool matches_path = MatchesPath(op, expr); memoize_ = true; - return matches; + if (matches_path) { + return DominatesParent(op, expr); + } } return false; } @@ -415,5 +417,142 @@ Expr RewritePatterns(Array callbacks, Expr expr) { TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns); +class PatternPartitioner : protected MixedModeMutator { + public: + Expr Partition(const DFPattern& pattern, const Expr& pre) { + pattern_ = pattern; + auto matcher = DFPatternMatcher(pre); + matcher_ = &matcher; + partitioning_ = true; + memo_.clear(); + this->VisitExpr(pre); + memo_.clear(); + partitioning_ = false; + return this->VisitExpr(pre); + } + + protected: + struct Group { + Expr root_node; + int gid; + Array args; + Function function; + }; + class PartitionRewriter : public ExprMutator { + public: + PartitionRewriter(const std::unordered_map& inputs) + : inputs_(inputs) {} + const std::unordered_map& GetMemo() { return this->memo_; } + + protected: + Expr VisitExpr(const Expr& pre) override { + if (inputs_.count(pre)) { + return inputs_.at(pre); + } + return ExprMutator::VisitExpr(pre); + } + const std::unordered_map inputs_; + }; + + void CreateGroup(const Expr& expr) { + var_number_ = 0; + + auto node_map = matcher_->GetMemo(); + auto pattern_graph = CreateIndexedGraph(pattern_); + // Get fuzzy patterns + std::unordered_set fuzzy_matches; + for (auto node : pattern_graph.topological_order_) { + if (auto op = node->ref_.as()) { + for (auto fuzzy_op : {op->parent, op->path}) { + for (auto match : node_map[fuzzy_op]) { + fuzzy_matches.insert(match); + } + } + } + } + // Create input variables + Group group; + group.root_node = expr; + std::unordered_map inputs; + Array params; + for (auto node : pattern_graph.topological_order_) { + if (node->inputs_.size() == 0) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + if (fuzzy_matches.count(match) == 0 && match.as() == nullptr && + match.as() == nullptr) { + inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" + + std::to_string(var_number_), + NullValue()); + group.args.push_back(match); + params.push_back(inputs[match]); + var_number_++; + } + } + } + } + } + graph_number_++; + + // Make the new function + auto rewriter = PartitionRewriter(inputs); + auto body = rewriter.Mutate(expr); + CHECK(DFPatternMatcher(body).Match(pattern_, body)); + group.function = Function(params, body, NullValue(), Array()); + // Check to make sure we aren't overlapping with another group + for (auto kv : rewriter.GetMemo()) { + if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 && + kv.first.as() == nullptr && kv.first.as() == nullptr) { + // Exit due to overlapping partitions + return; + } + } + group.gid = ++gid_; + for (auto kv : rewriter.GetMemo()) { + gid_assignments_[kv.first] = gid_; + } + groups_.emplace_back(std::move(group)); + CHECK_EQ(groups_[gid_].gid, gid_); + } + Expr RewriteParition(const Group& group) { + Array args; + for (size_t i = 0; i < group.args.size(); ++i) { + args.push_back(memo_[group.args[i]]); + } + return Call(group.function, args); + } + + Expr DispatchVisitExpr(const Expr& pre) override { + Expr post = pre; + if (partitioning_) { + if (matcher_->Match(pattern_, pre)) { + CreateGroup(post); + } + } else { + post = MixedModeMutator::DispatchVisitExpr(pre); + if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { + post = RewriteParition(groups_[gid_assignments_[pre]]); + } + } + return post; + } + + DFPatternMatcher* matcher_ = nullptr; + DFPattern pattern_; + std::unordered_map gid_assignments_; + std::vector groups_{Group()}; + bool partitioning_ = true; + int var_number_ = 0; + int graph_number_ = 0; + int gid_ = 0; +}; + +Expr PartitionPattern(DFPattern pattern, Expr expr) { + return PatternPartitioner().Partition(pattern, expr); +} + +TVM_REGISTER_GLOBAL("relay.df_pattern.partition").set_body_typed(PartitionPattern); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index cf25a19e67b8..d6258d594da8 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -548,29 +548,118 @@ def test_algebraic_simplify(): assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y) +def test_partition_dominator(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Classic Diamond + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp*inp, weight*weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + print(str(diamond.partition(out))) + +def test_quadruple_partition_dominator(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + + inp = relay.var('input') + weight = relay.var('weight') + # Classic Diamond + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Deeper Branch + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + relu = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Single Branch + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + out = relu + tanh + + # Fuzzy path/nested Diamond + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + tanh = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = tanh + leaky_relu + + print(str(diamond.partition(out))) + +def test_parition_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + + print(str(BatchnormCallback().pattern.partition(BN))) + +def test_parition_double_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + + print(str(BatchnormCallback().pattern.partition(BN2))) + if __name__ == "__main__": - #test_match_op() - #test_no_match_op() - #test_match_op_or() - #test_match_call() - #test_no_match_call() - #test_match_call_commutive() - #test_no_match_call_commutive() - #test_match_tuple() - #test_no_match_tuple() - #test_match_type() - #test_no_match_type() - #test_match_attr() - #test_no_match_attr() - #test_match_diamond() - #test_no_match_diamond() - #test_match_fake_diamond() - #test_rewrite() - #test_fuse_batchnorm() - #test_no_fuse_batchnorm() - #test_fuse_double_batchnorm() - #test_partial_fuse_double_batchnorm() - #test_fuse_batchnorm_commutation() - #test_match_dominator() - #test_not_match_dominator() + test_match_op() + test_no_match_op() + test_match_op_or() + test_match_call() + test_no_match_call() + test_match_call_commutive() + test_no_match_call_commutive() + test_match_tuple() + test_no_match_tuple() + test_match_type() + test_no_match_type() + test_match_attr() + test_no_match_attr() + test_match_diamond() + test_no_match_diamond() + test_match_fake_diamond() + test_rewrite() + test_fuse_batchnorm() + test_no_fuse_batchnorm() + test_fuse_double_batchnorm() + test_partial_fuse_double_batchnorm() + test_fuse_batchnorm_commutation() + test_match_dominator() + test_not_match_dominator() test_algebraic_simplify() + test_partition_dominator() + test_quadruple_partition_dominator() + test_parition_batchnorm() + test_parition_double_batchnorm() + From 0831f56e323b19ac236e4b95578433c9916e3a16 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 22 Apr 2020 10:13:17 -0700 Subject: [PATCH 32/46] refactor rewriter to handle back to back dominator patterns --- python/tvm/relay/df_pattern/__init__.py | 6 +- src/relay/ir/dataflow_matcher.cc | 266 +++++++++++++++--------- tests/python/relay/test_df_pattern.py | 186 ++++++++++++++--- 3 files changed, 326 insertions(+), 132 deletions(-) diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 41cb9c7451e7..79f35fdb436a 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -476,13 +476,13 @@ def rewrite(callbacks, expr: Expr) -> Expr: def partition(pattern: DFPattern, expr: Expr) -> Expr: """ - Rewrite expression with the given callbacks + Parition the expression into a series of functions that match the pattern Parameters ---------- partion: tvm.relay.df_pattern.DFPattern - The pattern to separate into functions + The pattern to match expr : tvm.relay.Expr - The expression to rewrite. + The expression to split into functions """ return ffi.partition(pattern, expr) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index c9e24e4f8ae2..4f6679e22daf 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -256,6 +256,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex return false; } +// Recursively find the Dominator parent along all inputs paths. bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { bool out = true; auto call_node = expr.as(); @@ -277,6 +278,7 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e return out; } +// Iteratively ensure that the parent is dominated somewhere by the child or the path bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) { std::stack stack; std::unordered_set visited; @@ -362,85 +364,57 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern patter return DFPatternMatcher(expr).Match(pattern, expr); }); -// Rewrite - -DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) { - ObjectPtr n = make_object(); - n->pattern_ = std::move(pattern); - n->function_ = std::move(function); - return DFPatternCallback(n); -} - -TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); - -TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") - .set_body_typed(DFPatternCallbackNode::make); - -class PatternRewriter : protected MixedModeMutator { +/* \brief PatternGrouper does pre-rewriting pattern matching and analysis + * + * This class creates a number of groups of matched expressions, ensures they don't overlap, and + * returns them to the caller for post-analysis rewriting. + * + * This is primarly needed to suppor the post-dominator analysis required for dominator pattern + * matching. + */ +class PatternGrouper : protected MixedModeVisitor { public: - PatternRewriter() {} - Expr Rewrite(const Array& callbacks, const Expr& pre) { - auto post = pre; - auto last = post; - // rewrite the graph until it stops changing to make sure all rewrites are complete - do { - last = post; - for (auto callback : callbacks) { - callback_ = &callback; - auto matcher = DFPatternMatcher(post); - matcher_ = &matcher; - memo_.clear(); - post = this->VisitExpr(post); - } - } while (last != post); - return post; - } - - protected: - Expr DispatchVisitExpr(const Expr& pre) override { - auto post = MixedModeMutator::DispatchVisitExpr(pre); - if (auto* callback_node = callback_->as()) { - if (matcher_->Match(callback_node->pattern_, post)) { - return callback_node->function_(pre, post, matcher_->GetMemo()); - } - } - return post; - } + /* \brief Internal Group class for storing analysis */ + struct Group { + Expr root_node; + int gid; + Map> matched_nodes; + Function function; + Array args; + }; - DFPatternMatcher* matcher_ = nullptr; - DFPatternCallback* callback_ = nullptr; -}; + /* \brief Return the discovered groups */ + const std::vector& GetGroups() { return this->groups_; } -Expr RewritePatterns(Array callbacks, Expr expr) { - return PatternRewriter().Rewrite(callbacks, expr); -} - -TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns); + /* \brief Return the group assingnments of expressions */ + const std::unordered_map& GetGIDAssignments() { + return gid_assignments_; + } + /* \brief Group expressions that match the pattern */ + void GroupMatches(const DFPattern& pattern, const Expr& pre) { + groups_ = {Group()}; + gid_assignments_.clear(); + visit_counter_.clear(); -class PatternPartitioner : protected MixedModeMutator { - public: - Expr Partition(const DFPattern& pattern, const Expr& pre) { pattern_ = pattern; + pattern_graph_ = CreateIndexedGraph(pattern_); auto matcher = DFPatternMatcher(pre); matcher_ = &matcher; - partitioning_ = true; - memo_.clear(); this->VisitExpr(pre); - memo_.clear(); - partitioning_ = false; - return this->VisitExpr(pre); } protected: - struct Group { - Expr root_node; - int gid; - Array args; - Function function; - }; - class PartitionRewriter : public ExprMutator { + void VisitLeaf(const Expr& pre) override { + if (matcher_->Match(pattern_, pre)) { + CreateGroup(pre); + } + } + + /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform + * group overlap analysis */ + class MatchExtractor : public ExprMutator { public: - PartitionRewriter(const std::unordered_map& inputs) + explicit MatchExtractor(const std::unordered_map& inputs) : inputs_(inputs) {} const std::unordered_map& GetMemo() { return this->memo_; } @@ -454,14 +428,15 @@ class PatternPartitioner : protected MixedModeMutator { const std::unordered_map inputs_; }; + /* \brief Create a group based on a matched expression */ void CreateGroup(const Expr& expr) { var_number_ = 0; auto node_map = matcher_->GetMemo(); - auto pattern_graph = CreateIndexedGraph(pattern_); + // Get fuzzy patterns std::unordered_set fuzzy_matches; - for (auto node : pattern_graph.topological_order_) { + for (auto node : pattern_graph_.topological_order_) { if (auto op = node->ref_.as()) { for (auto fuzzy_op : {op->parent, op->path}) { for (auto match : node_map[fuzzy_op]) { @@ -470,18 +445,21 @@ class PatternPartitioner : protected MixedModeMutator { } } } + // Create input variables Group group; group.root_node = expr; + group.matched_nodes = node_map; + std::unordered_map inputs; Array params; - for (auto node : pattern_graph.topological_order_) { + for (auto node : pattern_graph_.topological_order_) { if (node->inputs_.size() == 0) { if (node_map.count(node->ref_)) { auto matches = node_map[node->ref_]; for (auto match : matches) { if (fuzzy_matches.count(match) == 0 && match.as() == nullptr && - match.as() == nullptr) { + match.as() == nullptr && match.as() == nullptr) { inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number_), NullValue()); @@ -493,29 +471,140 @@ class PatternPartitioner : protected MixedModeMutator { } } } + graph_number_++; - // Make the new function - auto rewriter = PartitionRewriter(inputs); - auto body = rewriter.Mutate(expr); + // Extract a Function. Used in Parition directly, + // used to determine Group overlap in other passes + auto extractor = MatchExtractor(inputs); + auto body = extractor.Mutate(expr); + + // Verify the pattern still holds CHECK(DFPatternMatcher(body).Match(pattern_, body)); group.function = Function(params, body, NullValue(), Array()); + // Check to make sure we aren't overlapping with another group - for (auto kv : rewriter.GetMemo()) { + for (auto kv : extractor.GetMemo()) { if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 && - kv.first.as() == nullptr && kv.first.as() == nullptr) { + kv.first.as() == nullptr && kv.first.as() == nullptr && + kv.first.as() == nullptr) { // Exit due to overlapping partitions return; } } + // Assign Group Ids group.gid = ++gid_; - for (auto kv : rewriter.GetMemo()) { + for (auto kv : extractor.GetMemo()) { gid_assignments_[kv.first] = gid_; } + + // Save Group groups_.emplace_back(std::move(group)); CHECK_EQ(groups_[gid_].gid, gid_); } - Expr RewriteParition(const Group& group) { + + // Internal State + DFPattern pattern_; + std::vector groups_; + std::unordered_map gid_assignments_; + DFPatternMatcher* matcher_ = nullptr; + IndexedGraph pattern_graph_; + int gid_ = 0; + int var_number_ = 0; + int graph_number_ = 0; +}; + +// Rewrite + +DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) { + ObjectPtr n = make_object(); + n->pattern_ = std::move(pattern); + n->function_ = std::move(function); + return DFPatternCallback(n); +} + +TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); + +TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") + .set_body_typed(DFPatternCallbackNode::make); + +/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback + * function to rewrtie those matches + * + * The class uses PatternGrouper to support the dominator pattern. + */ +class PatternRewriter : protected MixedModeMutator { + public: + PatternRewriter() {} + /*! \brief Rewrite can take a number of callbakcs and will repeatedly rewrite the graph with the + * callbacks until it stops changing */ + Expr Rewrite(const Array& callbacks, const Expr& pre) { + auto post = pre; + auto last = post; + // rewrite the graph until it stops changing to make sure all rewrites are complete + do { + last = post; + for (auto callback : callbacks) { + callback_ = callback; + auto grouper = PatternGrouper(); + grouper.GroupMatches(callback_->pattern_, post); + groups_ = grouper.GetGroups(); + gid_assignments_ = grouper.GetGIDAssignments(); + memo_.clear(); + post = this->VisitExpr(post); + } + } while (last != post); + return post; + } + + protected: + Expr DispatchVisitExpr(const Expr& pre) override { + auto post = MixedModeMutator::DispatchVisitExpr(pre); + if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { + // Convert the pre-rewrite node map to a post-rewrite node map + auto group = groups_[gid_assignments_[pre]]; + std::unordered_map, ObjectHash, ObjectEqual> node_map; + for (auto kv : group.matched_nodes) { + Array tmp; + for (size_t i = 0; i < kv.second.size(); ++i) { + tmp.push_back(this->memo_[kv.second[i]]); + } + node_map.insert({kv.first, tmp}); + } + // run the user callback function + return callback_->function_(pre, post, Map>(node_map)); + } + return post; + } + + DFPatternCallback callback_; + std::vector groups_; + std::unordered_map gid_assignments_; +}; + +Expr RewritePatterns(Array callbacks, Expr expr) { + return PatternRewriter().Rewrite(callbacks, expr); +} + +TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns); + +/* \brief PatternParitioner replaces expressions that match a pattern with function call that + * perform the same computation but allow for further analysis and lowering. + * + * The class uses PatternGrouper to support the dominator pattern. + */ +class PatternPartitioner : protected MixedModeMutator { + public: + Expr Partition(const DFPattern& pattern, const Expr& pre) { + auto grouper = PatternGrouper(); + grouper.GroupMatches(pattern, pre); + groups_ = grouper.GetGroups(); + gid_assignments_ = grouper.GetGIDAssignments(); + return this->VisitExpr(pre); + } + + protected: + Expr RewriteParition(const PatternGrouper::Group& group) { Array args; for (size_t i = 0; i < group.args.size(); ++i) { args.push_back(memo_[group.args[i]]); @@ -524,28 +613,15 @@ class PatternPartitioner : protected MixedModeMutator { } Expr DispatchVisitExpr(const Expr& pre) override { - Expr post = pre; - if (partitioning_) { - if (matcher_->Match(pattern_, pre)) { - CreateGroup(post); - } - } else { - post = MixedModeMutator::DispatchVisitExpr(pre); - if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { - post = RewriteParition(groups_[gid_assignments_[pre]]); - } + auto post = MixedModeMutator::DispatchVisitExpr(pre); + if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { + post = RewriteParition(groups_[gid_assignments_[pre]]); } return post; } - DFPatternMatcher* matcher_ = nullptr; - DFPattern pattern_; + std::vector groups_; std::unordered_map gid_assignments_; - std::vector groups_{Group()}; - bool partitioning_ = true; - int var_number_ = 0; - int graph_number_ = 0; - int gid_ = 0; }; Expr PartitionPattern(DFPattern pattern, Expr expr) { diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index d6258d594da8..c7bf5356079a 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -472,6 +472,61 @@ def test_fuse_batchnorm_commutation(): out = rewrite(BatchnormCallback(), BN) assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) +def test_quadruple_rewrite_dominator(): + class DominatorRemovalCallback(DFPatternCallback): + def __init__(self): + self.inp = wildcard() + self.weight = wildcard() + + is_conv2d = is_op('nn.conv2d')(self.inp, self.weight) + is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction) + + def callback(self, pre, post, node_map): + inp = node_map[self.inp][0] + weight = node_map[self.weight][0] + return relay.op.nn.conv2d(inp, weight) + + inp = relay.var('input') + weight = relay.var('weight') + # Classic Diamond + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Deeper Branch + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + relu = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Single Branch + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + out = relu + tanh + + # Fuzzy path/nested Diamond + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + tanh = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = tanh + leaky_relu + + one = relay.op.nn.conv2d(inp, weight) + two = relay.op.nn.conv2d(one, weight) + three = relay.op.nn.conv2d(two, weight) + four = relay.op.nn.conv2d(three, weight) + + assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four) + def algebraic_simplify(expr): zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0))) one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0))) @@ -558,14 +613,20 @@ def test_partition_dominator(): # Classic Diamond inp = relay.var('input') weight = relay.var('weight') - conv2d = relay.op.nn.conv2d(inp*inp, weight*weight) - relu = relay.op.nn.relu(conv2d) - relu = relay.op.nn.relu(relu) - leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) - out = relu + leaky_relu - + def generate_diamond(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + return relu + leaky_relu + out = generate_diamond(inp*inp, weight*weight) # Check - print(str(diamond.partition(out))) + partitioned = diamond.partition(out) + + i = relay.Var("input") + w = relay.Var("weight") + f = relay.Function([i, w], generate_diamond(i, w)) + assert tvm.ir.structural_equal(partitioned, f(inp*inp, weight*weight)) def test_quadruple_partition_dominator(): # Pattern @@ -578,36 +639,68 @@ def test_quadruple_partition_dominator(): inp = relay.var('input') weight = relay.var('weight') # Classic Diamond - conv2d = relay.op.nn.conv2d(inp, weight) - relu = relay.op.nn.relu(conv2d) - relu = relay.op.nn.relu(relu) - leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) - out = relu + leaky_relu + def classic_diamond(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + return relu + leaky_relu # Deeper Branch - conv2d = relay.op.nn.conv2d(out, weight) - relu = relay.op.nn.relu(conv2d) - relu = relay.op.nn.relu(relu) - relu = relay.op.tanh(relu) - leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) - out = relu + leaky_relu + def deeper_diamond(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + relu = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + return relu + leaky_relu # Single Branch - conv2d = relay.op.nn.conv2d(out, weight) - relu = relay.op.nn.relu(conv2d) - relu = relay.op.nn.relu(relu) - tanh = relay.op.tanh(relu) - out = relu + tanh + def single_branch(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + return relu + tanh # Fuzzy path/nested Diamond - conv2d = relay.op.nn.conv2d(out, weight) - relu = relay.op.nn.relu(conv2d) - relu = relu + relu - tanh = relay.op.tanh(relu) - leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) - out = tanh + leaky_relu - - print(str(diamond.partition(out))) + def nested_diamond(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + tanh = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + return tanh + leaky_relu + + partitioned = diamond.partition( + nested_diamond( + single_branch( + deeper_diamond( + classic_diamond(inp, weight), + weight), + weight), + weight + ) + ) + + functions = [] + for f in [classic_diamond, deeper_diamond, single_branch, nested_diamond]: + inpf = relay.var("input") + weightf = relay.var("weight") + functions.append(relay.Function([inpf, weightf], f(inpf, weightf))) + + reference = functions[3]( + functions[2]( + functions[1]( + functions[0](inp, weight), + weight), + weight), + weight + ) + assert tvm.ir.structural_equal(partitioned, reference) + +def get_BN(x, var, mean, beta, gamma, eps = 1e-5): + return gamma * (x - mean)/relay.op.sqrt(var + relay.const(eps)) + beta def test_parition_batchnorm(): x = relay.var('x') @@ -615,10 +708,19 @@ def test_parition_batchnorm(): mean = relay.var('mean') beta = relay.var('beta') gamma = relay.var('gamma') + BN = get_BN(x, var, mean, beta, gamma) + - BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + xf = relay.var('xf') + varf = relay.var('varf') + meanf = relay.var('meanf') + betaf = relay.var('betaf') + gammaf = relay.var('gammaf') + # Put the arguments in toplogological order for the reference + f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)) - print(str(BatchnormCallback().pattern.partition(BN))) + partitioned = BatchnormCallback().pattern.partition(BN) + assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta)) def test_parition_double_batchnorm(): x = relay.var('x') @@ -630,7 +732,23 @@ def test_parition_double_batchnorm(): BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta - print(str(BatchnormCallback().pattern.partition(BN2))) + xf = relay.var('xf') + varf = relay.var('varf') + meanf = relay.var('meanf') + betaf = relay.var('betaf') + gammaf = relay.var('gammaf') + f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)) + # The paritioner doesn't replace duplicates, so we use two copies of the function + xf2 = relay.var('xf2') + varf2 = relay.var('varf2') + meanf2 = relay.var('meanf2') + betaf2 = relay.var('betaf2') + gammaf2 = relay.var('gammaf2') + f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2)) + + partitioned = BatchnormCallback().pattern.partition(BN2) + reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta) + assert tvm.ir.structural_equal(partitioned, reference) if __name__ == "__main__": test_match_op() From 44eb6dd359a42e5280dbadac355bcbe91d9a9d37 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 24 Apr 2020 10:54:17 -0700 Subject: [PATCH 33/46] respond to review comments --- docs/langref/relay_pattern.rst | 2 +- include/tvm/relay/dataflow_functor.h | 2 +- include/tvm/relay/dataflow_pattern.h | 6 +- python/tvm/relay/df_pattern/__init__.py | 75 +++++++++++++++++++++++++ src/relay/ir/dataflow_matcher.cc | 30 +++++++--- src/relay/ir/dataflow_pattern.cc | 2 +- tests/python/relay/test_df_pattern.py | 33 +++++++---- 7 files changed, 125 insertions(+), 25 deletions(-) diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index cf47e9ca51d9..cc4c80e79e57 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -24,7 +24,7 @@ There are many places in TVM where we identify pure data-flow sub-graphs of the Many of these passes today require a lots of boring boilerplate code in order to implement as well as requiring users to think in terms of visitors and AST matching. Many of these transformations can easily be described in terms of graph rewrites. In order to build a rewriter or other advanced machinery we first need a language of patterns to describe what we can match. -Such a language is not just useful for building a rewriter but also providing extension points for existing passes. For example the fusion pass could be parametrized by a set of fusion patterns which describes the capability of your hardware, and the quantization pass could take a set of patterns which describe which operators can be quantized on a given platform. +Such a language is not just useful for building a rewriter but also providing extension points for existing passes. For example the fusion pass could be parameterized by a set of fusion patterns which describes the capability of your hardware, and the quantization pass could take a set of patterns which describe which operators can be quantized on a given platform. In the backend world, we could use the same machinery to build a higher level API using bring your own code generation. This API takes set of patterns describing your hardware capabilities and an external compiler, providing a relatively smooth heterogeneous experience out of the box. diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_functor.h index a0edeb224f62..c405004750a4 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_functor.h @@ -37,7 +37,7 @@ namespace relay { /*! * \brief A dynamical functor that dispatches on in the first DFPattern argument. * - * \tparam FType function signiture + * \tparam FType function signature * This type is only defined for FType with function signature R(const DFPattern&, * Args...) */ diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index b151c13ed0be..25cd7481e188 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -91,7 +91,7 @@ class VarPatternNode : public DFPatternNode { */ std::string name; /*! - * \brief type annotaion of the variable. + * \brief type annotation of the variable. * This field records user provided type annotation of the Var. * This field is optional and can be None. */ @@ -103,6 +103,7 @@ class VarPatternNode : public DFPatternNode { } void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); v->Visit("type_annotation", &type_annotation); } @@ -212,7 +213,8 @@ class TupleGetItemPatternNode : public DFPatternNode { void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("tuple_value", &tuple); + v->Visit("tuple", &tuple); + v->Visit("index", &index); } TVM_DLL static TupleGetItemPattern make(DFPattern tuple, int index); diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/df_pattern/__init__.py index 79f35fdb436a..640c5148f1b6 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/df_pattern/__init__.py @@ -70,6 +70,11 @@ def has_attr(self, attr_name: str, attr_value): The name of the attribute to match attr_value: Any The value of the attribute to match + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting AttrPattern """ attrs = make_node("DictAttrs", **{attr_name: attr_value}) return AttrPattern(self, attrs) @@ -82,6 +87,11 @@ def has_type(self, ttype): ---------- ttype: tvm.relay.Type The type to match + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting TypePattern """ return has_type(ttype, self) @@ -93,6 +103,11 @@ def match(self, expr: Expr) -> bool: ---------- expr : tvm.relay.Expr The expression to match. + + Returns + ------- + result: bool + Whether or not the expression matches the pattern """ return match(self, expr) @@ -104,6 +119,11 @@ def partition(self, expr: Expr) -> bool: ---------- expr : tvm.relay.Expr The expression to match. + + Returns + ------- + result : tvm.relay.Expr + The Expression with matched subgraphs replaced by function calls to that subgraph """ return partition(self, expr) @@ -117,6 +137,11 @@ def dominates(self, parent, path=None): The parent pattern this pattern dominates. path: tvm.relay.df_pattern.DFPattern The fuzzy path pattern. + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting DominatorPattern """ if path is None: path = wildcard() @@ -131,6 +156,11 @@ def is_input(name: str = "") -> DFPattern: ---------- name: str The name of the input pattern to match + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting InputPattern """ return VarPattern(name) @@ -143,6 +173,11 @@ def is_op(op_name: str) -> DFPattern: ---------- op_name: String The name of the relay op + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting ExprPattern """ op = get(op_name) return ExprPattern(op) @@ -151,6 +186,11 @@ def is_op(op_name: str) -> DFPattern: def wildcard() -> DFPattern: """ Syntatic sugar for creating a WildcardPattern + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting WildcardPattern """ return WildcardPattern() @@ -166,6 +206,11 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern: ttype: tvm.relay.Type The type to match + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting TypePattern """ if pattern is None: pattern = wildcard() @@ -183,6 +228,11 @@ def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern: attrs: tvm.Attrs The attributes to match + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting AttrPattern """ if pattern is None: pattern = wildcard() @@ -201,6 +251,11 @@ def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern The fuzzy path pattern. child: tvm.relay.df_pattern.DFPattern The child pattern. + + Returns + ------- + result: tvm.relay.df_pattern.DFPattern + The resulting DominatorPattern """ return DominatorPattern(parent, path, child) @@ -429,6 +484,11 @@ def rewrite(self, expr: Expr) -> Expr: ---------- expr : tvm.relay.Expr The expression to rewrite. + + Returns + ------- + result : tvm.relay.Expr + The Expression with matched subgraphs rewritten by the callbacks """ return rewrite(self, expr) @@ -444,6 +504,11 @@ def callback(self, pre, post, node_map): The matching expression with rewritten inputs node_map : Map(DFPattern, List(Expr)) The map between patterns and matched expressions + + Returns + ------- + result : tvm.relay.Expr + The Expression with matched subgraph rewritten by the callback """ raise "Unimplemented" @@ -464,6 +529,11 @@ def rewrite(callbacks, expr: Expr) -> Expr: The input callback or list of callbacks. expr : tvm.relay.Expr The expression to rewrite. + + Returns + ------- + result : tvm.relay.Expr + The Expression with matched subgraphs rewritten by the callbacks """ if isinstance(callbacks, DFPatternCallback): tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)] @@ -484,5 +554,10 @@ def partition(pattern: DFPattern, expr: Expr) -> Expr: The pattern to match expr : tvm.relay.Expr The expression to split into functions + + Returns + ------- + result : tvm.relay.Expr + The Expression with matched subgraphs replaced by function calls to that subgraph """ return ffi.partition(pattern, expr) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 4f6679e22daf..8063d0877f76 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -369,7 +369,7 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern patter * This class creates a number of groups of matched expressions, ensures they don't overlap, and * returns them to the caller for post-analysis rewriting. * - * This is primarly needed to suppor the post-dominator analysis required for dominator pattern + * This is primarily needed to suppor the post-dominator analysis required for dominator pattern * matching. */ class PatternGrouper : protected MixedModeVisitor { @@ -386,7 +386,7 @@ class PatternGrouper : protected MixedModeVisitor { /* \brief Return the discovered groups */ const std::vector& GetGroups() { return this->groups_; } - /* \brief Return the group assingnments of expressions */ + /* \brief Return the group assignments of expressions */ const std::unordered_map& GetGIDAssignments() { return gid_assignments_; } @@ -474,7 +474,7 @@ class PatternGrouper : protected MixedModeVisitor { graph_number_++; - // Extract a Function. Used in Parition directly, + // Extract a Function. Used in Partition directly, // used to determine Group overlap in other passes auto extractor = MatchExtractor(inputs); auto body = extractor.Mutate(expr); @@ -484,6 +484,13 @@ class PatternGrouper : protected MixedModeVisitor { group.function = Function(params, body, NullValue(), Array()); // Check to make sure we aren't overlapping with another group + // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the + // pattern with the input FunctionVar* Variables. The resulting memoization map will only + // contain nodes in the expression that matched the pattern. If a non-input node of the pattern + // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a + // situation where we try to rewrite the same node twice in the second rewriting or parition + // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants + // because they exist more globally outside of the fusion. for (auto kv : extractor.GetMemo()) { if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 && kv.first.as() == nullptr && kv.first.as() == nullptr && @@ -529,19 +536,20 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") .set_body_typed(DFPatternCallbackNode::make); /* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback - * function to rewrtie those matches + * function to rewrite those matches * * The class uses PatternGrouper to support the dominator pattern. */ class PatternRewriter : protected MixedModeMutator { public: PatternRewriter() {} - /*! \brief Rewrite can take a number of callbakcs and will repeatedly rewrite the graph with the + /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the * callbacks until it stops changing */ Expr Rewrite(const Array& callbacks, const Expr& pre) { auto post = pre; auto last = post; // rewrite the graph until it stops changing to make sure all rewrites are complete + int count = 0; do { last = post; for (auto callback : callbacks) { @@ -552,8 +560,12 @@ class PatternRewriter : protected MixedModeMutator { gid_assignments_ = grouper.GetGIDAssignments(); memo_.clear(); post = this->VisitExpr(post); + count++; } - } while (last != post); + } while (last != post || count >= 100); + if (count >= 100) { + throw("Observed 100 rewrite passes, possible conflicting passes?"); + } return post; } @@ -588,7 +600,7 @@ Expr RewritePatterns(Array callbacks, Expr expr) { TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns); -/* \brief PatternParitioner replaces expressions that match a pattern with function call that +/* \brief PatternPartitioner replaces expressions that match a pattern with function call that * perform the same computation but allow for further analysis and lowering. * * The class uses PatternGrouper to support the dominator pattern. @@ -604,7 +616,7 @@ class PatternPartitioner : protected MixedModeMutator { } protected: - Expr RewriteParition(const PatternGrouper::Group& group) { + Expr RewritePartition(const PatternGrouper::Group& group) { Array args; for (size_t i = 0; i < group.args.size(); ++i) { args.push_back(memo_[group.args[i]]); @@ -615,7 +627,7 @@ class PatternPartitioner : protected MixedModeMutator { Expr DispatchVisitExpr(const Expr& pre) override { auto post = MixedModeMutator::DispatchVisitExpr(pre); if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { - post = RewriteParition(groups_[gid_assignments_[pre]]); + post = RewritePartition(groups_[gid_assignments_[pre]]); } return post; } diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 8568205541dc..25e4f0b472c4 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -79,7 +79,7 @@ CallPattern CallPatternNode::make(DFPattern op, Array args, Attrs att return CallPattern(n); } -TVM_REGISTER_NODE_TYPE(CallNode); +TVM_REGISTER_NODE_TYPE(CallPatternNode); TVM_REGISTER_GLOBAL("relay.df_pattern.CallPattern") .set_body_typed(CallPatternNode::make); diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index c7bf5356079a..c12c7989a7a0 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -28,47 +28,58 @@ ## NODE TESTS def test_expr_pattern(): ep = ExprPattern(relay.var('x', shape=(4, 1))) - print(ep) + assert isinstance(ep, ExprPattern) + assert isinstance(ep.expr, relay.Var) def test_var_pattern(): v = is_input("x") - print(v) + assert isinstance(v, VarPattern) + assert v.name == "x" def test_wildcard_pattern(): wc = wildcard() - print(wc) + assert isinstance(wc, WildcardPattern) def test_CallPattern(): wc1 = wildcard() wc2 = wildcard() c = is_op("add")(wc1, wc2) - print(c) + assert isinstance(c, CallPattern) + assert isinstance(c.args[0], WildcardPattern) + assert isinstance(c.args[1], WildcardPattern) def test_TuplePattern(): wc1 = wildcard() wc2 = wildcard() t = TuplePattern([wc1, wc2]) - print(t) + assert isinstance(t, TuplePattern) + assert isinstance(t.fields[0], WildcardPattern) + assert isinstance(t.fields[1], WildcardPattern) def test_TupleGetItemPattern(): wc1 = wildcard() wc2 = wildcard() t = TuplePattern([wc1, wc2]) tgi = TupleGetItemPattern(t, 1) - print(tgi) + assert isinstance(tgi, TupleGetItemPattern) + assert isinstance(tgi.tuple, TuplePattern) + assert isinstance(tgi.tuple.fields[0], WildcardPattern) + assert isinstance(tgi.tuple.fields[1], WildcardPattern) def test_AltPattern(): is_add_or_sub = is_op('add') | is_op('subtract') - print(is_add_or_sub) + assert isinstance(is_add_or_sub, AltPattern) def test_TypePattern(): - ty_pat = has_type(relay.TensorType((10, 10), "float32")) - print(ty_pat) + ttype = relay.TensorType((10, 10), "float32") + ty_pat = has_type(ttype) + assert isinstance(ty_pat, TypePattern) + assert ty_pat.type == ttype def test_AttrPattern(): op = is_op('add').has_attr("TOpPattern", K_ELEMWISE) - op_pat = op(wildcard(), wildcard()) - print(op_pat) + assert isinstance(op, AttrPattern) + assert op.attrs["TOpPattern"] == K_ELEMWISE ## MATCHER TESTS From 49d4ab976934bc77da18ab6ae5030f6cd683804f Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 24 Apr 2020 11:17:31 -0700 Subject: [PATCH 34/46] respond to review comments --- include/tvm/relay/dataflow_matcher.h | 2 +- include/tvm/relay/dataflow_pattern.h | 20 +++--- ...w_functor.h => dataflow_pattern_functor.h} | 6 +- .../__init__.py | 62 +++++++++---------- .../{df_pattern => dataflow_pattern}/_ffi.py | 2 +- src/relay/ir/dataflow_matcher.cc | 22 +++---- src/relay/ir/dataflow_pattern.cc | 21 ++++--- ...functor.cc => dataflow_pattern_functor.cc} | 2 +- ...df_pattern.py => test_dataflow_pattern.py} | 2 +- 9 files changed, 70 insertions(+), 69 deletions(-) rename include/tvm/relay/{dataflow_functor.h => dataflow_pattern_functor.h} (98%) rename python/tvm/relay/{df_pattern => dataflow_pattern}/__init__.py (89%) rename python/tvm/relay/{df_pattern => dataflow_pattern}/_ffi.py (93%) rename src/relay/ir/{dataflow_functor.cc => dataflow_pattern_functor.cc} (99%) rename tests/python/relay/{test_df_pattern.py => test_dataflow_pattern.py} (99%) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 9bf7e15b0933..59ff9cd2776a 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAY_DATAFLOW_MATCHER_H_ #define TVM_RELAY_DATAFLOW_MATCHER_H_ -#include #include +#include #include #include diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 25cd7481e188..f2ff7c9d71ea 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -62,7 +62,7 @@ class ExprPatternNode : public DFPatternNode { v->Visit("expr", &expr); } - static constexpr const char* _type_key = "relay.df_pattern.ExprPattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.ExprPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); }; @@ -109,7 +109,7 @@ class VarPatternNode : public DFPatternNode { TVM_DLL static VarPattern make(std::string name_hint, Type type_annotation); - static constexpr const char* _type_key = "relay.df_pattern.VarPattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.VarPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode); }; @@ -170,7 +170,7 @@ class CallPatternNode : public DFPatternNode { TVM_DLL static CallPattern make(DFPattern op, Array args, Attrs attrs, Array type_args); - static constexpr const char* _type_key = "relay.df_pattern.CallPattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); }; @@ -193,7 +193,7 @@ class TuplePatternNode : public DFPatternNode { TVM_DLL static TuplePattern make(tvm::Array fields); - static constexpr const char* _type_key = "relay.df_pattern.TuplePattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.TuplePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); }; @@ -219,7 +219,7 @@ class TupleGetItemPatternNode : public DFPatternNode { TVM_DLL static TupleGetItemPattern make(DFPattern tuple, int index); - static constexpr const char* _type_key = "relay.df_pattern.TupleGetItemPattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.TupleGetItemPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); }; @@ -246,7 +246,7 @@ class AltPatternNode : public DFPatternNode { TVM_DLL static AltPattern make(DFPattern left, DFPattern right); - static constexpr const char* _type_key = "relay.df_pattern.AltPattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.AltPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode); }; @@ -266,7 +266,7 @@ class WildcardPatternNode : public DFPatternNode { public: void VisitAttrs(tvm::AttrVisitor* v) {} - static constexpr const char* _type_key = "relay.df_pattern.WildcardPattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); }; @@ -296,7 +296,7 @@ class TypePatternNode : public DFPatternNode { TVM_DLL static TypePattern make(DFPattern pattern, Type type); - static constexpr const char* _type_key = "relay.df_pattern.TypePattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.TypePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); }; @@ -326,7 +326,7 @@ class AttrPatternNode : public DFPatternNode { TVM_DLL static AttrPattern make(DFPattern pattern, Attrs attrs); - static constexpr const char* _type_key = "relay.df_pattern.AttrPattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.AttrPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); }; @@ -361,7 +361,7 @@ class DominatorPatternNode : public DFPatternNode { TVM_DLL static DominatorPattern make(DFPattern parent, DFPattern path, DFPattern child); - static constexpr const char* _type_key = "relay.df_pattern.DominatorPattern"; + static constexpr const char* _type_key = "relay.dataflow_pattern.DominatorPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(DominatorPatternNode, DFPatternNode); }; diff --git a/include/tvm/relay/dataflow_functor.h b/include/tvm/relay/dataflow_pattern_functor.h similarity index 98% rename from include/tvm/relay/dataflow_functor.h rename to include/tvm/relay/dataflow_pattern_functor.h index c405004750a4..82590ade8c68 100644 --- a/include/tvm/relay/dataflow_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -21,8 +21,8 @@ * \file tvm/relay/dataflow_matcher.h * \brief A pattern matcher for matching dataflow properties. */ -#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_ -#define TVM_RELAY_DATAFLOW_FUNCTOR_H_ +#ifndef TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ +#define TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ #include #include @@ -245,4 +245,4 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_DATAFLOW_FUNCTOR_H_ +#endif // TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ diff --git a/python/tvm/relay/df_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py similarity index 89% rename from python/tvm/relay/df_pattern/__init__.py rename to python/tvm/relay/dataflow_pattern/__init__.py index 640c5148f1b6..9a98ee2e9f0c 100644 --- a/python/tvm/relay/df_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -34,7 +34,7 @@ def register_df_node(type_key=None): """ if not isinstance(type_key, str): return tvm_ffi.register_object( - "relay.df_pattern." + type_key.__name__)(type_key) + "relay.dataflow_pattern." + type_key.__name__)(type_key) return tvm_ffi.register_object(type_key) @@ -73,7 +73,7 @@ def has_attr(self, attr_name: str, attr_value): Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting AttrPattern """ attrs = make_node("DictAttrs", **{attr_name: attr_value}) @@ -90,7 +90,7 @@ def has_type(self, ttype): Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting TypePattern """ return has_type(ttype, self) @@ -133,14 +133,14 @@ def dominates(self, parent, path=None): Parameters ---------- - parent: tvm.relay.df_pattern.DFPattern + parent: tvm.relay.dataflow_pattern.DFPattern The parent pattern this pattern dominates. - path: tvm.relay.df_pattern.DFPattern + path: tvm.relay.dataflow_pattern.DFPattern The fuzzy path pattern. Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting DominatorPattern """ if path is None: @@ -159,7 +159,7 @@ def is_input(name: str = "") -> DFPattern: Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting InputPattern """ return VarPattern(name) @@ -176,7 +176,7 @@ def is_op(op_name: str) -> DFPattern: Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting ExprPattern """ op = get(op_name) @@ -189,7 +189,7 @@ def wildcard() -> DFPattern: Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting WildcardPattern """ return WildcardPattern() @@ -201,7 +201,7 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern: Parameters ---------- - pattern: tvm.relay.df_pattern.DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The pattern that needs type annotation ttype: tvm.relay.Type @@ -209,7 +209,7 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern: Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting TypePattern """ if pattern is None: @@ -223,7 +223,7 @@ def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern: Parameters ---------- - pattern: tvm.relay.df_pattern.DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The input pattern. attrs: tvm.Attrs @@ -231,7 +231,7 @@ def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern: Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting AttrPattern """ if pattern is None: @@ -245,16 +245,16 @@ def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern Parameters ---------- - parent: tvm.relay.df_pattern.DFPattern + parent: tvm.relay.dataflow_pattern.DFPattern The parent pattern. - path: tvm.relay.df_pattern.DFPattern + path: tvm.relay.dataflow_pattern.DFPattern The fuzzy path pattern. - child: tvm.relay.df_pattern.DFPattern + child: tvm.relay.dataflow_pattern.DFPattern The child pattern. Returns ------- - result: tvm.relay.df_pattern.DFPattern + result: tvm.relay.dataflow_pattern.DFPattern The resulting DominatorPattern """ return DominatorPattern(parent, path, child) @@ -266,7 +266,7 @@ def match(pattern: DFPattern, expr: Expr) -> bool: Parameters ---------- - pattern: tvm.relay.df_pattern.DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The input pattern. expr : tvm.relay.Expr The expression to match. @@ -317,10 +317,10 @@ class CallPattern(DFPattern): Parameters ---------- - op: realy.df_pattern.DFPattern + op: realy.dataflow_pattern.DFPattern The operation to be called. - args: List[realy.df_pattern.DFPattern] + args: List[realy.dataflow_pattern.DFPattern] The arguments to the call. attrs: Optional[tvm.Attrs] @@ -344,7 +344,7 @@ class TuplePattern(DFPattern): Parameters ---------- - fields : List[tvm.relay.df_pattern.DFPattern] + fields : List[tvm.relay.dataflow_pattern.DFPattern] The fields in the tuple. """ @@ -369,7 +369,7 @@ class TupleGetItemPattern(DFPattern): Parameters ---------- - tuple_value: tvm.relay.df_pattern.DFPattern + tuple_value: tvm.relay.dataflow_pattern.DFPattern The input tuple expression. index: int @@ -387,9 +387,9 @@ class AltPattern(DFPattern): Parameters ---------- - left: tvm.relay.df_pattern.DFPattern + left: tvm.relay.dataflow_pattern.DFPattern One possible matching Pattern - right: tvm.relay.df_pattern.DFPattern + right: tvm.relay.dataflow_pattern.DFPattern One possible matching Pattern """ @@ -413,7 +413,7 @@ class TypePattern(DFPattern): Parameters ---------- - pattern: tvm.relay.df_pattern.DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The input pattern that needs type annotation ttype: tvm.relay.Type @@ -432,7 +432,7 @@ class AttrPattern(DFPattern): Parameters ---------- - pattern: tvm.relay.df_pattern.DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The input pattern. attrs: tvm.Attrs @@ -450,13 +450,13 @@ class DominatorPattern(DFPattern): Parameters ---------- - parent: tvm.relay.df_pattern.DFPattern + parent: tvm.relay.dataflow_pattern.DFPattern The parent, i.e., the single node which produces something, later aggregated by the child - path: tvm.relay.df_pattern.DFPattern + path: tvm.relay.dataflow_pattern.DFPattern The fuzzy path pattern between parent and child, typically matches elementwise ops - child: tvm.relay.df_pattern.DFPattern + child: tvm.relay.dataflow_pattern.DFPattern The last node in the domination which is the end user for all nodes in the path and the parent """ @@ -525,7 +525,7 @@ def rewrite(callbacks, expr: Expr) -> Expr: Parameters ---------- - callbacks: tvm.relay.df_pattern.DFPatternCallback + callbacks: tvm.relay.dataflow_pattern.DFPatternCallback The input callback or list of callbacks. expr : tvm.relay.Expr The expression to rewrite. @@ -550,7 +550,7 @@ def partition(pattern: DFPattern, expr: Expr) -> Expr: Parameters ---------- - partion: tvm.relay.df_pattern.DFPattern + partion: tvm.relay.dataflow_pattern.DFPattern The pattern to match expr : tvm.relay.Expr The expression to split into functions diff --git a/python/tvm/relay/df_pattern/_ffi.py b/python/tvm/relay/dataflow_pattern/_ffi.py similarity index 93% rename from python/tvm/relay/df_pattern/_ffi.py rename to python/tvm/relay/dataflow_pattern/_ffi.py index 2049f4217efb..b0a702c1d2f5 100644 --- a/python/tvm/relay/df_pattern/_ffi.py +++ b/python/tvm/relay/dataflow_pattern/_ffi.py @@ -17,4 +17,4 @@ """DataFlow Pattern Language FFI bindings.""" import tvm._ffi -tvm._ffi._init_api("relay.df_pattern", __name__) +tvm._ffi._init_api("relay.dataflow_pattern", __name__) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 8063d0877f76..5ee6f095bfe2 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -33,7 +33,6 @@ namespace relay { // Pattern Matcher - class DominatorMatcher; class DFPatternMatcher : public DFPatternFunctor { @@ -214,8 +213,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex if (is_pattern_op(op, "divide")) { if (const auto* arg_node = op->args[0].as()) { if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") && - (is_expr_op(call_node->args[0], "divide") || - is_expr_op(call_node->args[1], "divide"))) { + (is_expr_op(call_node->args[0], "divide") || + is_expr_op(call_node->args[1], "divide"))) { bool out = false; for (size_t arg_id = 0; arg_id < 2; ++arg_id) { auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]}, @@ -239,8 +238,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex for (size_t arg_id = 0; arg_id < 2; ++arg_id) { if (auto* arg_node = op->args[arg_id].as()) { if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") && - (is_expr_op(call_node->args[0], "multiply") || - is_expr_op(call_node->args[1], "multiply"))) { + (is_expr_op(call_node->args[0], "multiply") || + is_expr_op(call_node->args[1], "multiply"))) { auto mul = CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, op->attrs, op->type_args); @@ -360,9 +359,10 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } -TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) { - return DFPatternMatcher(expr).Match(pattern, expr); -}); +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match") + .set_body_typed([](DFPattern pattern, Expr expr) { + return DFPatternMatcher(expr).Match(pattern, expr); + }); /* \brief PatternGrouper does pre-rewriting pattern matching and analysis * @@ -532,7 +532,7 @@ DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc func TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") .set_body_typed(DFPatternCallbackNode::make); /* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback @@ -598,7 +598,7 @@ Expr RewritePatterns(Array callbacks, Expr expr) { return PatternRewriter().Rewrite(callbacks, expr); } -TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns); +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns); /* \brief PatternPartitioner replaces expressions that match a pattern with function call that * perform the same computation but allow for further analysis and lowering. @@ -640,7 +640,7 @@ Expr PartitionPattern(DFPattern pattern, Expr expr) { return PatternPartitioner().Partition(pattern, expr); } -TVM_REGISTER_GLOBAL("relay.df_pattern.partition").set_body_typed(PartitionPattern); +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition").set_body_typed(PartitionPattern); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 25e4f0b472c4..4dfda55f5d2d 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -34,7 +34,7 @@ ExprPattern::ExprPattern(Expr expr) { TVM_REGISTER_NODE_TYPE(ExprPatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.ExprPattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern") .set_body_typed([](Expr e) { return ExprPattern(e); }); @@ -55,7 +55,7 @@ VarPattern VarPatternNode::make(std::string name_hint, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarPatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.VarPattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern") .set_body_typed(static_cast(VarPatternNode::make)); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -81,7 +81,7 @@ CallPattern CallPatternNode::make(DFPattern op, Array args, Attrs att TVM_REGISTER_NODE_TYPE(CallPatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.CallPattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern") .set_body_typed(CallPatternNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -99,7 +99,7 @@ TuplePattern TuplePatternNode::make(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TuplePatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.TuplePattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TuplePattern") .set_body_typed(TuplePatternNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -117,7 +117,7 @@ TupleGetItemPattern TupleGetItemPatternNode::make(DFPattern tuple, int index) { TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.TupleGetItemPattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TupleGetItemPattern") .set_body_typed(TupleGetItemPatternNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -135,7 +135,7 @@ AltPattern AltPatternNode::make(DFPattern left, DFPattern right) { TVM_REGISTER_NODE_TYPE(AltPatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.AltPattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AltPattern") .set_body_typed(AltPatternNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -146,7 +146,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(WildcardPatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.WildcardPattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern") .set_body_typed([]() { auto w = WildcardPattern(make_object()); return w; @@ -166,7 +166,7 @@ TypePattern TypePatternNode::make(DFPattern pattern, Type type) { TVM_REGISTER_NODE_TYPE(TypePatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.TypePattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TypePattern") .set_body_typed(TypePatternNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -184,7 +184,7 @@ AttrPattern AttrPatternNode::make(DFPattern pattern, Attrs attrs) { TVM_REGISTER_NODE_TYPE(AttrPatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.AttrPattern") +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern") .set_body_typed(AttrPatternNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -203,7 +203,8 @@ DominatorPattern DominatorPatternNode::make(DFPattern parent, DFPattern path, DF TVM_REGISTER_NODE_TYPE(DominatorPatternNode); -TVM_REGISTER_GLOBAL("relay.df_pattern.DominatorPattern").set_body_typed(DominatorPatternNode::make); +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DominatorPattern") + .set_body_typed(DominatorPatternNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/ir/dataflow_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc similarity index 99% rename from src/relay/ir/dataflow_functor.cc rename to src/relay/ir/dataflow_pattern_functor.cc index 15d41475d0da..d8202994dbb8 100644 --- a/src/relay/ir/dataflow_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -23,7 +23,7 @@ */ #include -#include +#include #include #include diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_dataflow_pattern.py similarity index 99% rename from tests/python/relay/test_df_pattern.py rename to tests/python/relay/test_dataflow_pattern.py index c12c7989a7a0..e2c6d7d60ac3 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import relay -from tvm.relay.df_pattern import * +from tvm.relay.dataflow_pattern import * import numpy as np # NB: 1 corresponds to the C++ enum that specicfies this From 6f751697d7b9de08cea7396bdbb609b6fc809f93 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 4 May 2020 09:46:54 -0700 Subject: [PATCH 35/46] Respond to Tianqi's Comments Move IndexedGraph to it's own header, edit python imports --- include/tvm/relay/dataflow_pattern_functor.h | 102 ------- python/tvm/relay/dataflow_pattern/__init__.py | 5 +- src/relay/ir/dataflow_matcher.cc | 1 + src/relay/ir/dataflow_pattern_functor.cc | 245 ---------------- src/relay/ir/indexed_graph.cc | 276 ++++++++++++++++++ src/relay/ir/indexed_graph.h | 137 +++++++++ 6 files changed, 417 insertions(+), 349 deletions(-) create mode 100644 src/relay/ir/indexed_graph.cc create mode 100644 src/relay/ir/indexed_graph.h diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index 82590ade8c68..ac8b35af514a 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -25,11 +25,7 @@ #define TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ #include -#include -#include #include -#include -#include namespace tvm { namespace relay { @@ -145,104 +141,6 @@ class DFPatternVisitor : public DFPatternFunctor { std::unordered_set visited_; }; -/*! - * \brief A Wrapper around a templated graph type - * Holds a forward-backward indexed representation of the graph and a dominator tree representation - * of the graph - * - * This class is templated and the implementaiton is in the header file so we can analyze both - * DFPattern and Expr with the same infrastructure. - * - * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. - */ -template -class IndexedGraph { - public: - /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */ - struct Node { - /*! \brief Node Constructor - * \param ref The input graph node - * \param index The index of the node in toplogical order - */ - Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} - - /*! \brief The input node */ - const T ref_; - /*! \brief The topological order index */ - const size_t index_; - - /*! \brief A boolean to determine if this node is external to the graph */ - bool is_external_ = false; - /*! \brief The forward inputs of the node */ - std::vector inputs_; - /*! \brief The forward outputs/users of the node */ - std::vector outputs_; - - /*! \brief The depth of the node in the dominator tree */ - size_t depth_; - /*! \brief The dominator parent/final user of the outputs of this node */ - Node* dominator_parent_; - /*! \brief The nodes this node dominates */ - std::vector dominator_children_; - }; - /*! \brief Construct the domination tree inside IndexedGraph */ - void PostDom() { - for (size_t i = topological_order_.size(); i != 0; --i) { - size_t index = i - 1; - auto* current = topological_order_[index].get(); - if (current->is_external_) { - current->depth_ = 1; - current->dominator_parent_ = nullptr; - } else { - auto parent = LeastCommonAncestor(current->outputs_); - current->depth_ = parent ? parent->depth_ + 1 : 1; - current->dominator_parent_ = parent; - parent->dominator_children_.push_back(current); - } - } - } - /*! \brief Map of input nodes to IndexedGraph Nodes */ - std::unordered_map, ObjectHash, ObjectEqual> node_map_; - /*! \brief Topological IndexedGraph Nodes */ - std::vector> topological_order_; - - protected: - /*! \brief Find the least common ancestor of all outputs of a node */ - Node* LeastCommonAncestor(const std::vector& outputs) { - if (outputs.size() == 0) { - return nullptr; - } - auto parent = outputs.at(0); - for (size_t i = 1; i < outputs.size(); ++i) { - parent = LeastCommonAncestor(parent, outputs.at(i)); - } - return parent; - } - - /*! \brief Find the least common ancestor of two nodes */ - Node* LeastCommonAncestor(Node* lhs, Node* rhs) { - if (lhs == nullptr || rhs == nullptr) { - return nullptr; - } - while (lhs != rhs) { - if (lhs->depth_ < rhs->depth_) { - rhs = rhs->dominator_parent_; - } else if (lhs->depth_ > rhs->depth_) { - lhs = lhs->dominator_parent_; - } else { - rhs = rhs->dominator_parent_; - lhs = lhs->dominator_parent_; - } - } - return lhs; - } -}; - -/*! \brief Create an Indexed Graph based on an Expr */ -IndexedGraph CreateIndexedGraph(const Expr& expr); -/*! \brief Create an Indexed Graph based on an DFPattern */ -IndexedGraph CreateIndexedGraph(const DFPattern& pattern); - } // namespace relay } // namespace tvm #endif // TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 9a98ee2e9f0c..4c5e94a2c1c4 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -22,6 +22,7 @@ from ... import _ffi as tvm_ffi from ..op import get from . import _ffi as ffi +import tvm._ffi def register_df_node(type_key=None): @@ -33,9 +34,9 @@ def register_df_node(type_key=None): The type key of the node. """ if not isinstance(type_key, str): - return tvm_ffi.register_object( + return tvm._ffi.register_object( "relay.dataflow_pattern." + type_key.__name__)(type_key) - return tvm_ffi.register_object(type_key) + return tvm._ffi.register_object(type_key) class DFPattern(Node): diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 5ee6f095bfe2..70904282bb00 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -27,6 +27,7 @@ #include #include #include +#include "indexed_graph.h" namespace tvm { namespace relay { diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc index d8202994dbb8..c7c34c804449 100644 --- a/src/relay/ir/dataflow_pattern_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -22,10 +22,7 @@ * \brief The dataflow pattern matcher for Relay. */ -#include #include -#include -#include namespace tvm { namespace relay { @@ -76,247 +73,5 @@ void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} -// IndexedGraph - -IndexedGraph CreateIndexedGraph(const Expr& expr) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ - class Creator : public MixedModeVisitor { - public: - IndexedGraph CreateGraph(const Expr& expr) { - VisitExpr(expr); - graph_.node_map_[expr]->is_external_ = true; - return std::move(graph_); - } - - protected: - void VisitLeaf(const Expr& expr) override { - MixedModeVisitor::VisitLeaf(expr); - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); - } - IndexedGraph graph_; - size_t index_ = 0; - }; - /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree - * analysis. - * - * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined - * topological order instead of recursing. - */ - class Annotator : public ExprFunctor { - public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { - // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - ExprFunctor::VisitExpr(node->ref_, nullptr); - } - // do the dominator analysis - graph_.PostDom(); - return std::move(graph_); - } - - /*! Default visitation pushes the parent to the child's ouputs and the child to the parent's - * inputs*/ - void VisitExpr(const Expr& expr, NodePtr parent) override { - auto current = graph_.node_map_[expr]; - if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); - } - } - - protected: - IndexedGraph graph_; - void VisitExpr_(const VarNode* op, NodePtr parent) override { - if (op->type_annotation.defined()) { - this->VisitType(op->type_annotation); - } - } - - void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} - - void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} - - void VisitExpr_(const TupleNode* op, NodePtr parent) override { - for (auto field : op->fields) { - this->VisitExpr(field, graph_.node_map_[GetRef(op)]); - } - } - - void VisitExpr_(const FunctionNode* op, NodePtr parent) override { - for (auto param : op->params) { - this->VisitExpr(param, graph_.node_map_[GetRef(op)]); - } - - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); - } - - void VisitExpr_(const CallNode* op, NodePtr parent) override { - this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); - - for (auto ty_arg : op->type_args) { - this->VisitType(ty_arg); - } - - for (auto arg : op->args) { - this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); - } - } - - void VisitExpr_(const LetNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); - } - - void VisitExpr_(const IfNode* op, NodePtr parent) override { - this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); - } - - void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } - - void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { - this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); - } - - void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); - } - - void VisitExpr_(const RefReadNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); - } - - void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); - } - - void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { - for (const Type& t : op->inputs) { - this->VisitType(t); - } - this->VisitType(op->belong_to); - } - - void VisitExpr_(const MatchNode* op, NodePtr parent) override { - this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); - for (const Clause& c : op->clauses) { - this->VisitClause(c, graph_.node_map_[GetRef(op)]); - } - } - - void VisitClause(const Clause& op, NodePtr parent) { - this->VisitPattern(op->lhs); - this->VisitExpr(op->rhs, parent); - } - - void VisitPattern(const Pattern& p) { return; } - - void VisitType(const Type& t) { return; } - }; - return Annotator(Creator().CreateGraph(expr)).Annotate(); -} - -IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ - class Creator : public DFPatternVisitor { - public: - IndexedGraph CreateGraph(const DFPattern& pattern) { - VisitDFPattern(pattern); - graph_.node_map_[pattern]->is_external_ = true; - return std::move(graph_); - } - - protected: - void VisitDFPattern(const DFPattern& pattern) override { - DFPatternVisitor::VisitDFPattern(pattern); - auto node = std::make_shared::Node>(pattern, index_++); - graph_.node_map_[pattern] = node; - graph_.topological_order_.push_back(node); - } - IndexedGraph graph_; - size_t index_ = 0; - }; - /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree - * analysis. - * - * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined - * topological order instead of recursing. - */ - class Annotator : public DFPatternFunctor { - public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { - // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); - } - graph_.PostDom(); - // do the dominator analysis - return std::move(graph_); - } - - /*! Default visitation pushes the parent to the child's ouputs */ - void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { - auto current = graph_.node_map_[pattern]; - if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); - } - } - - protected: - IndexedGraph graph_; - void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); - } - - void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); - } - - void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); - for (auto arg : op->args) { - VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); - } - } - void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); - } - - void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} - - void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); - } - - void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { - for (auto field : op->fields) { - VisitDFPattern(field, graph_.node_map_[GetRef(op)]); - } - } - - void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); - } - - void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} - - void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} - }; - return Annotator(Creator().CreateGraph(pattern)).Annotate(); -} - } // namespace relay } // namespace tvm diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc new file mode 100644 index 000000000000..74dafc5ec647 --- /dev/null +++ b/src/relay/ir/indexed_graph.cc @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/ir/indexed_graph.cc + * \brief Utilties for Creating Indexed Graphs. + */ +#include +#include +#include +#include +#include "indexed_graph.h" + +namespace tvm { +namespace relay { + +// IndexedGraph + +IndexedGraph CreateIndexedGraph(const Expr& expr) { + using NodePtr = std::shared_ptr::Node>; + /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ + class Creator : public MixedModeVisitor { + public: + IndexedGraph CreateGraph(const Expr& expr) { + VisitExpr(expr); + graph_.node_map_[expr]->is_external_ = true; + return std::move(graph_); + } + + protected: + void VisitLeaf(const Expr& expr) override { + MixedModeVisitor::VisitLeaf(expr); + auto node = std::make_shared::Node>(expr, index_++); + graph_.node_map_[expr] = node; + graph_.topological_order_.push_back(node); + } + IndexedGraph graph_; + size_t index_ = 0; + }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree + * analysis. + * + * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined + * topological order instead of recursing. + */ + class Annotator : public ExprFunctor { + public: + Annotator(const IndexedGraph& graph) : graph_(graph) {} + IndexedGraph Annotate() { + // Visit all of the nodes in topological order to get forward outputs + for (const auto& node : graph_.topological_order_) { + ExprFunctor::VisitExpr(node->ref_, nullptr); + } + // do the dominator analysis + graph_.PostDom(); + return std::move(graph_); + } + + /*! Default visitation pushes the parent to the child's ouputs and the child to the parent's + * inputs*/ + void VisitExpr(const Expr& expr, NodePtr parent) override { + auto current = graph_.node_map_[expr]; + if (parent) { + current->outputs_.push_back(parent.get()); + parent->inputs_.push_back(current.get()); + } + } + + protected: + IndexedGraph graph_; + void VisitExpr_(const VarNode* op, NodePtr parent) override { + if (op->type_annotation.defined()) { + this->VisitType(op->type_annotation); + } + } + + void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} + + void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} + + void VisitExpr_(const TupleNode* op, NodePtr parent) override { + for (auto field : op->fields) { + this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + } + } + + void VisitExpr_(const FunctionNode* op, NodePtr parent) override { + for (auto param : op->params) { + this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + } + + this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const CallNode* op, NodePtr parent) override { + this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + } + } + + void VisitExpr_(const LetNode* op, NodePtr parent) override { + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const IfNode* op, NodePtr parent) override { + this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } + + void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { + this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefReadNode* op, NodePtr parent) override { + this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { + this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { + for (const Type& t : op->inputs) { + this->VisitType(t); + } + this->VisitType(op->belong_to); + } + + void VisitExpr_(const MatchNode* op, NodePtr parent) override { + this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); + for (const Clause& c : op->clauses) { + this->VisitClause(c, graph_.node_map_[GetRef(op)]); + } + } + + void VisitClause(const Clause& op, NodePtr parent) { + this->VisitPattern(op->lhs); + this->VisitExpr(op->rhs, parent); + } + + void VisitPattern(const Pattern& p) { return; } + + void VisitType(const Type& t) { return; } + }; + return Annotator(Creator().CreateGraph(expr)).Annotate(); +} + +IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { + using NodePtr = std::shared_ptr::Node>; + /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ + class Creator : public DFPatternVisitor { + public: + IndexedGraph CreateGraph(const DFPattern& pattern) { + VisitDFPattern(pattern); + graph_.node_map_[pattern]->is_external_ = true; + return std::move(graph_); + } + + protected: + void VisitDFPattern(const DFPattern& pattern) override { + DFPatternVisitor::VisitDFPattern(pattern); + auto node = std::make_shared::Node>(pattern, index_++); + graph_.node_map_[pattern] = node; + graph_.topological_order_.push_back(node); + } + IndexedGraph graph_; + size_t index_ = 0; + }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree + * analysis. + * + * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined + * topological order instead of recursing. + */ + class Annotator : public DFPatternFunctor { + public: + Annotator(const IndexedGraph& graph) : graph_(graph) {} + IndexedGraph Annotate() { + // Visit all of the nodes in topological order to get forward outputs + for (const auto& node : graph_.topological_order_) { + DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); + } + graph_.PostDom(); + // do the dominator analysis + return std::move(graph_); + } + + /*! Default visitation pushes the parent to the child's ouputs */ + void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { + auto current = graph_.node_map_[pattern]; + if (parent) { + current->outputs_.push_back(parent.get()); + parent->inputs_.push_back(current.get()); + } + } + + protected: + IndexedGraph graph_; + void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); + for (auto arg : op->args) { + VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + } + } + void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + + void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { + for (auto field : op->fields) { + VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + } + } + + void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} + + void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} + }; + return Annotator(Creator().CreateGraph(pattern)).Annotate(); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h new file mode 100644 index 000000000000..635526538425 --- /dev/null +++ b/src/relay/ir/indexed_graph.h @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/ir/indexed_graph.h + * \brief A pattern matcher for matching dataflow properties. + */ +#ifndef TVM_RELAY_INDEXED_GRAPH_H_ +#define TVM_RELAY_INDEXED_GRAPH_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief A Wrapper around a templated graph type + * Holds a forward-backward indexed representation of the graph and a dominator tree representation + * of the graph + * + * This class is templated and the implementaiton is in the header file so we can analyze both + * DFPattern and Expr with the same infrastructure. + * + * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. + */ +template +class IndexedGraph { + public: + /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */ + struct Node { + /*! \brief Node Constructor + * \param ref The input graph node + * \param index The index of the node in toplogical order + */ + Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + + /*! \brief The input node */ + const T ref_; + /*! \brief The topological order index */ + const size_t index_; + + /*! \brief A boolean to determine if this node is external to the graph */ + bool is_external_ = false; + /*! \brief The forward inputs of the node */ + std::vector inputs_; + /*! \brief The forward outputs/users of the node */ + std::vector outputs_; + + /*! \brief The depth of the node in the dominator tree */ + size_t depth_; + /*! \brief The dominator parent/final user of the outputs of this node */ + Node* dominator_parent_; + /*! \brief The nodes this node dominates */ + std::vector dominator_children_; + }; + /*! \brief Construct the domination tree inside IndexedGraph */ + void PostDom() { + for (size_t i = topological_order_.size(); i != 0; --i) { + size_t index = i - 1; + auto* current = topological_order_[index].get(); + if (current->is_external_) { + current->depth_ = 1; + current->dominator_parent_ = nullptr; + } else { + auto parent = LeastCommonAncestor(current->outputs_); + current->depth_ = parent ? parent->depth_ + 1 : 1; + current->dominator_parent_ = parent; + parent->dominator_children_.push_back(current); + } + } + } + /*! \brief Map of input nodes to IndexedGraph Nodes */ + std::unordered_map, ObjectHash, ObjectEqual> node_map_; + /*! \brief Topological IndexedGraph Nodes */ + std::vector> topological_order_; + + protected: + /*! \brief Find the least common ancestor of all outputs of a node */ + Node* LeastCommonAncestor(const std::vector& outputs) { + if (outputs.size() == 0) { + return nullptr; + } + auto parent = outputs.at(0); + for (size_t i = 1; i < outputs.size(); ++i) { + parent = LeastCommonAncestor(parent, outputs.at(i)); + } + return parent; + } + + /*! \brief Find the least common ancestor of two nodes */ + Node* LeastCommonAncestor(Node* lhs, Node* rhs) { + if (lhs == nullptr || rhs == nullptr) { + return nullptr; + } + while (lhs != rhs) { + if (lhs->depth_ < rhs->depth_) { + rhs = rhs->dominator_parent_; + } else if (lhs->depth_ > rhs->depth_) { + lhs = lhs->dominator_parent_; + } else { + rhs = rhs->dominator_parent_; + lhs = lhs->dominator_parent_; + } + } + return lhs; + } +}; + +/*! \brief Create an Indexed Graph based on an Expr */ +IndexedGraph CreateIndexedGraph(const Expr& expr); +/*! \brief Create an Indexed Graph based on an DFPattern */ +IndexedGraph CreateIndexedGraph(const DFPattern& pattern); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_INDEXED_GRAPH_H_ From c4725f2161fe802d2e39da2194faea4b4afcff65 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 4 May 2020 10:25:38 -0700 Subject: [PATCH 36/46] fix lint --- include/tvm/relay/dataflow_pattern_functor.h | 1 + python/tvm/relay/dataflow_pattern/__init__.py | 2 +- src/relay/ir/indexed_graph.h | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index ac8b35af514a..21b8d499b2a3 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -26,6 +26,7 @@ #include #include +#include namespace tvm { namespace relay { diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 4c5e94a2c1c4..cde229c6604a 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -16,13 +16,13 @@ # under the License. """The Relay Pattern Language and tooling.""" from tvm.relay import Expr +import tvm._ffi from ...ir.base import Node from ...ir import make_node from ...runtime import Object from ... import _ffi as tvm_ffi from ..op import get from . import _ffi as ffi -import tvm._ffi def register_df_node(type_key=None): diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index 635526538425..c569c16b06a5 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -21,8 +21,8 @@ * \file src/relay/ir/indexed_graph.h * \brief A pattern matcher for matching dataflow properties. */ -#ifndef TVM_RELAY_INDEXED_GRAPH_H_ -#define TVM_RELAY_INDEXED_GRAPH_H_ +#ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_ +#define TVM_RELAY_IR_INDEXED_GRAPH_H_ #include #include @@ -134,4 +134,4 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_INDEXED_GRAPH_H_ +#endif // TVM_RELAY_IR_INDEXED_GRAPH_H_ From ebc72614ca26149077cc2203c973f857bd876d69 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 8 May 2020 10:43:11 -0700 Subject: [PATCH 37/46] refactor to respond to zhiic's comments --- include/tvm/relay/dataflow_pattern.h | 32 ++---- src/relay/ir/dataflow_matcher.cc | 18 ++-- src/relay/ir/dataflow_pattern.cc | 156 ++++++++++++++------------- 3 files changed, 98 insertions(+), 108 deletions(-) diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index f2ff7c9d71ea..6f2faa8aac4f 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -26,7 +26,6 @@ #include #include -#include namespace tvm { namespace relay { @@ -74,7 +73,7 @@ class ExprPatternNode : public DFPatternNode { */ class ExprPattern : public DFPattern { public: - TVM_DLL ExprPattern(Expr expr); + TVM_DLL explicit ExprPattern(Expr expr); TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); }; @@ -89,7 +88,7 @@ class VarPatternNode : public DFPatternNode { /*! * \brief The name of the Var (optional). */ - std::string name; + String name; /*! * \brief type annotation of the variable. * This field records user provided type annotation of the Var. @@ -98,7 +97,7 @@ class VarPatternNode : public DFPatternNode { Type type_annotation; /*! \return The name hint of the variable */ - const std::string& name_hint() const { + const String& name_hint() const { return name; } @@ -107,14 +106,13 @@ class VarPatternNode : public DFPatternNode { v->Visit("type_annotation", &type_annotation); } - TVM_DLL static VarPattern make(std::string name_hint, Type type_annotation); - static constexpr const char* _type_key = "relay.dataflow_pattern.VarPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode); }; class VarPattern : public DFPattern { public: + TVM_DLL VarPattern(String name_hint, Type type_annotation); TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); }; @@ -167,15 +165,13 @@ class CallPatternNode : public DFPatternNode { v->Visit("type_args", &type_args); } - TVM_DLL static CallPattern make(DFPattern op, Array args, Attrs attrs, - Array type_args); - static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); }; class CallPattern : public DFPattern { public: + TVM_DLL CallPattern(DFPattern op, Array args, Attrs attrs, Array type_args); TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); }; @@ -191,14 +187,13 @@ class TuplePatternNode : public DFPatternNode { v->Visit("fields", &fields); } - TVM_DLL static TuplePattern make(tvm::Array fields); - static constexpr const char* _type_key = "relay.dataflow_pattern.TuplePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); }; class TuplePattern : public DFPattern { public: + TVM_DLL explicit TuplePattern(tvm::Array fields); TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); }; @@ -217,14 +212,13 @@ class TupleGetItemPatternNode : public DFPatternNode { v->Visit("index", &index); } - TVM_DLL static TupleGetItemPattern make(DFPattern tuple, int index); - static constexpr const char* _type_key = "relay.dataflow_pattern.TupleGetItemPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); }; class TupleGetItemPattern : public DFPattern { public: + TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); }; @@ -244,8 +238,6 @@ class AltPatternNode : public DFPatternNode { v->Visit("right", &right); } - TVM_DLL static AltPattern make(DFPattern left, DFPattern right); - static constexpr const char* _type_key = "relay.dataflow_pattern.AltPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode); }; @@ -255,6 +247,7 @@ class AltPatternNode : public DFPatternNode { */ class AltPattern : public DFPattern { public: + TVM_DLL AltPattern(DFPattern left, DFPattern right); TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode); }; @@ -294,8 +287,6 @@ class TypePatternNode : public DFPatternNode { v->Visit("type", &type); } - TVM_DLL static TypePattern make(DFPattern pattern, Type type); - static constexpr const char* _type_key = "relay.dataflow_pattern.TypePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); }; @@ -305,6 +296,7 @@ class TypePatternNode : public DFPatternNode { */ class TypePattern : public DFPattern { public: + TVM_DLL TypePattern(DFPattern pattern, Type type); TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); }; @@ -324,8 +316,6 @@ class AttrPatternNode : public DFPatternNode { v->Visit("attrs", &attrs); } - TVM_DLL static AttrPattern make(DFPattern pattern, Attrs attrs); - static constexpr const char* _type_key = "relay.dataflow_pattern.AttrPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); }; @@ -335,6 +325,7 @@ class AttrPatternNode : public DFPatternNode { */ class AttrPattern : public DFPattern { public: + TVM_DLL AttrPattern(DFPattern pattern, Attrs attrs); TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); }; @@ -359,8 +350,6 @@ class DominatorPatternNode : public DFPatternNode { v->Visit("child", &child); } - TVM_DLL static DominatorPattern make(DFPattern parent, DFPattern path, DFPattern child); - static constexpr const char* _type_key = "relay.dataflow_pattern.DominatorPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(DominatorPatternNode, DFPatternNode); }; @@ -370,6 +359,7 @@ class DominatorPatternNode : public DFPatternNode { */ class DominatorPattern : public DFPattern { public: + TVM_DLL DominatorPattern(DFPattern parent, DFPattern path, DFPattern child); TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode); }; diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 70904282bb00..17c0b61d34ba 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -218,11 +218,10 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex is_expr_op(call_node->args[1], "divide"))) { bool out = false; for (size_t arg_id = 0; arg_id < 2; ++arg_id) { - auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]}, - op->attrs, op->type_args); - auto mul = - CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}, - arg_node->attrs, arg_node->type_args); + auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs, + op->type_args); + auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}, + arg_node->attrs, arg_node->type_args); out = VisitDFPattern(mul, expr); if (out) { return true; @@ -241,11 +240,10 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") && (is_expr_op(call_node->args[0], "multiply") || is_expr_op(call_node->args[1], "multiply"))) { - auto mul = - CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, - op->attrs, op->type_args); - auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]}, - arg_node->attrs, arg_node->type_args); + auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, + op->attrs, op->type_args); + auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs, + arg_node->type_args); return VisitDFPattern(div, expr); } } diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 4dfda55f5d2d..826a035ca6ba 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -34,177 +34,179 @@ ExprPattern::ExprPattern(Expr expr) { TVM_REGISTER_NODE_TYPE(ExprPatternNode); -TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern") -.set_body_typed([](Expr e) { - return ExprPattern(e); - }); +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern").set_body_typed([](Expr e) { + return ExprPattern(e); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->Print(node->expr); - }); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->Print(node->expr); + }); -VarPattern VarPatternNode::make(std::string name_hint, Type type_annotation) { +VarPattern::VarPattern(String name_hint, Type type_annotation) { ObjectPtr n = make_object(); n->name = std::move(name_hint); n->type_annotation = std::move(type_annotation); - return VarPattern(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(VarPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern") -.set_body_typed(static_cast(VarPatternNode::make)); + .set_body_typed([](String name_hint, Type type_annotation) { + return VarPattern(name_hint, type_annotation); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "VarPattern(" << node->name_hint(); - if (node->type_annotation.defined()) { - p->stream << ", ty="; - p->Print(node->type_annotation); - } - p->stream << ")"; - }); - -CallPattern CallPatternNode::make(DFPattern op, Array args, Attrs attrs, - Array type_args) { + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "VarPattern(" << node->name_hint(); + if (node->type_annotation.defined()) { + p->stream << ", ty="; + p->Print(node->type_annotation); + } + p->stream << ")"; + }); + +CallPattern::CallPattern(DFPattern op, Array args, Attrs attrs, Array type_args) { ObjectPtr n = make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); n->type_args = std::move(type_args); - return CallPattern(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(CallPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern") -.set_body_typed(CallPatternNode::make); + .set_body_typed([](DFPattern op, Array args, Attrs attrs, Array type_args) { + return CallPattern(op, args, attrs, type_args); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", " << node->attrs - << ", " << node->type_args << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", " << node->attrs + << ", " << node->type_args << ")"; + }); -TuplePattern TuplePatternNode::make(tvm::Array fields) { +TuplePattern::TuplePattern(tvm::Array fields) { ObjectPtr n = make_object(); n->fields = std::move(fields); - return TuplePattern(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TuplePatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TuplePattern") -.set_body_typed(TuplePatternNode::make); + .set_body_typed([](tvm::Array fields) { return TuplePattern(fields); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TuplePattern(" << node->fields << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TuplePattern(" << node->fields << ")"; + }); -TupleGetItemPattern TupleGetItemPatternNode::make(DFPattern tuple, int index) { +TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { ObjectPtr n = make_object(); n->tuple = std::move(tuple); n->index = index; - return TupleGetItemPattern(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TupleGetItemPattern") -.set_body_typed(TupleGetItemPatternNode::make); + .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << node->index << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << node->index << ")"; + }); -AltPattern AltPatternNode::make(DFPattern left, DFPattern right) { +AltPattern::AltPattern(DFPattern left, DFPattern right) { ObjectPtr n = make_object(); n->left = std::move(left); n->right = std::move(right); - return AltPattern(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(AltPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AltPattern") -.set_body_typed(AltPatternNode::make); + .set_body_typed([](DFPattern left, DFPattern right) { return AltPattern(left, right); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "AltPattern(" << node->left << " | " << node->right << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "AltPattern(" << node->left << " | " << node->right << ")"; + }); TVM_REGISTER_NODE_TYPE(WildcardPatternNode); -TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern") -.set_body_typed([]() { - auto w = WildcardPattern(make_object()); - return w; - }); +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() { + auto w = WildcardPattern(make_object()); + return w; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "*"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "*"; + }); -TypePattern TypePatternNode::make(DFPattern pattern, Type type) { +TypePattern::TypePattern(DFPattern pattern, Type type) { ObjectPtr n = make_object(); n->pattern = std::move(pattern); n->type = std::move(type); - return TypePattern(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TypePatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TypePattern") -.set_body_typed(TypePatternNode::make); + .set_body_typed([](DFPattern pattern, Type type) { return TypePattern(pattern, type); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; + }); -AttrPattern AttrPatternNode::make(DFPattern pattern, Attrs attrs) { +AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) { ObjectPtr n = make_object(); n->pattern = std::move(pattern); n->attrs = std::move(attrs); - return AttrPattern(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(AttrPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern") -.set_body_typed(AttrPatternNode::make); + .set_body_typed([](DFPattern pattern, Attrs attrs) { return AttrPattern(pattern, attrs); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; + }); -DominatorPattern DominatorPatternNode::make(DFPattern parent, DFPattern path, DFPattern child) { +DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, DFPattern child) { ObjectPtr n = make_object(); n->parent = std::move(parent); n->path = std::move(path); n->child = std::move(child); - return DominatorPattern(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(DominatorPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DominatorPattern") - .set_body_typed(DominatorPatternNode::make); + .set_body_typed([](DFPattern parent, DFPattern path, DFPattern child) { + return DominatorPattern(parent, path, child); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { From bf5283f89ddc730ba45a17613d470287210bfeb9 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 8 May 2020 10:49:48 -0700 Subject: [PATCH 38/46] refactor callback node --- include/tvm/relay/dataflow_matcher.h | 3 +-- src/relay/ir/dataflow_matcher.cc | 8 +++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 59ff9cd2776a..7c46d77c32db 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -46,8 +46,6 @@ class DFPatternCallbackNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) {} - TVM_DLL static DFPatternCallback make(DFPattern pattern, PackedFunc callback); - static constexpr const char* _type_key = "DFPatternCallbackNode"; TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object); }; @@ -58,6 +56,7 @@ class DFPatternCallbackNode : public Object { */ class DFPatternCallback : public ObjectRef { public: + TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback); TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); }; diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 17c0b61d34ba..f88a6eb710ad 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -522,17 +522,19 @@ class PatternGrouper : protected MixedModeVisitor { // Rewrite -DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) { +DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function) { ObjectPtr n = make_object(); n->pattern_ = std::move(pattern); n->function_ = std::move(function); - return DFPatternCallback(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") - .set_body_typed(DFPatternCallbackNode::make); + .set_body_typed([](DFPattern pattern, PackedFunc function) { + return DFPatternCallback(pattern, function); + }); /* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback * function to rewrite those matches From 01a231e63593d387a35e8e1a2b8818ace4fef69a Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 12 May 2020 11:38:30 -0700 Subject: [PATCH 39/46] respond to review comments --- docs/langref/relay_pattern.rst | 8 +++----- include/tvm/relay/dataflow_pattern_functor.h | 2 -- src/relay/ir/dataflow_matcher.cc | 11 ++++------- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index cc4c80e79e57..d955366a865e 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -28,8 +28,6 @@ Such a language is not just useful for building a rewriter but also providing ex In the backend world, we could use the same machinery to build a higher level API using bring your own code generation. This API takes set of patterns describing your hardware capabilities and an external compiler, providing a relatively smooth heterogeneous experience out of the box. -Recently there has been lots of discussion on similar issues in the community, and we wanted to gather feedback and hopefully collaborate on a design that can benefit everyone working in this space. This RFC focuses on the pattern language with future applications to come later. - Examples ======== @@ -69,13 +67,13 @@ The next example is matching a diamond with two inputs at the top of the diamond # Check assert diamond.match(out) -The final example we would like to match which is not yet implemented in the prototype is matching diamonds with a post-dominator relationship. Our plan is to embed dominator analysis as type of matching in the pattern language in order to allow for pattern matching with unknown topology. This is important because we want to able to use the language to describe fuse patterns, like elementwise operations followed by a conv2d:: +The final example is matching diamonds with a post-dominator relationship. We embed dominator analysis as type of matching in the pattern language in order to allow for pattern matching with unknown topology. This is important because we want to be able to use the language to describe fuse patterns, like elementwise operations followed by a conv2d:: def test_match_dom_diamond(): # Pattern is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) reduction = is_op('add')(wildcard(), wildcard()) - diamond = dominates(is_conv2d, is_elemwise, reduction) + diamond = dominates(is_conv2d, is_elemwise, reduction) # Expr inp = relay.var('input') @@ -140,4 +138,4 @@ Either match the first pattern or the second pattern. Domination ****************** -Match the parent pattern for the route node and then check that the child pattern holds for each child along the domination path. +Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern. diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index 21b8d499b2a3..cac914220d9b 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -57,8 +57,6 @@ class DFPatternFunctor { using FType = tvm::NodeFunctor; public: - /*! \brief the result type of this functor */ - using result_type = R; /*! \brief virtual destructor */ virtual ~DFPatternFunctor() {} /*! diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index f88a6eb710ad..bde445cf4cec 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -127,7 +127,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons } break; default: - throw "Unsupported type"; + CHECK(false) << "Unsupported type in Type Pattern Node"; } } } @@ -256,7 +256,6 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // Recursively find the Dominator parent along all inputs paths. bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { - bool out = true; auto call_node = expr.as(); for (auto node : expr_graph_.node_map_[expr]->inputs_) { if (!(call_node && node->ref_ == call_node->op)) { @@ -265,15 +264,13 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e return true; } else { memoize_ = false; - if (VisitDFPattern(op->path, node->ref_)) { - out &= MatchesPath(op, node->ref_); - } else { + if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) { return false; } } } } - return out; + return true; } // Iteratively ensure that the parent is dominated somewhere by the child or the path @@ -368,7 +365,7 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match") * This class creates a number of groups of matched expressions, ensures they don't overlap, and * returns them to the caller for post-analysis rewriting. * - * This is primarily needed to suppor the post-dominator analysis required for dominator pattern + * This is primarily needed to support the post-dominator analysis required for dominator pattern * matching. */ class PatternGrouper : protected MixedModeVisitor { From 911ba2f47571f78b2a219ce63ada66f686eae143 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 12 May 2020 11:41:46 -0700 Subject: [PATCH 40/46] upgrade from clang-format-6 to clang-format-10 --- include/tvm/relay/dataflow_matcher.h | 1 + include/tvm/relay/dataflow_pattern.h | 17 ++++------------- include/tvm/relay/dataflow_pattern_functor.h | 1 + src/relay/ir/dataflow_matcher.cc | 2 ++ src/relay/ir/indexed_graph.cc | 3 ++- src/relay/ir/indexed_graph.h | 1 + 6 files changed, 11 insertions(+), 14 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 7c46d77c32db..2c3601c368f8 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -26,6 +26,7 @@ #include #include + #include #include diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 6f2faa8aac4f..a8db51f74574 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -57,9 +57,7 @@ class ExprPatternNode : public DFPatternNode { /*! \brief The expression to match. */ Expr expr; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("expr", &expr); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } static constexpr const char* _type_key = "relay.dataflow_pattern.ExprPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); @@ -77,7 +75,6 @@ class ExprPattern : public DFPattern { TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); }; - /*! * \brief A Pattern to Match a Relay Variable */ @@ -97,9 +94,7 @@ class VarPatternNode : public DFPatternNode { Type type_annotation; /*! \return The name hint of the variable */ - const String& name_hint() const { - return name; - } + const String& name_hint() const { return name; } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); @@ -183,9 +178,7 @@ class TuplePatternNode : public DFPatternNode { /*! \brief the fields of the tuple */ tvm::Array fields; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("fields", &fields); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } static constexpr const char* _type_key = "relay.dataflow_pattern.TuplePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); @@ -206,7 +199,6 @@ class TupleGetItemPatternNode : public DFPatternNode { /*! \brief which value to get */ int index; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple", &tuple); v->Visit("index", &index); @@ -251,7 +243,6 @@ class AltPattern : public DFPattern { TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode); }; - /*! * \brief Wildcard Pattern. */ @@ -359,7 +350,7 @@ class DominatorPatternNode : public DFPatternNode { */ class DominatorPattern : public DFPattern { public: - TVM_DLL DominatorPattern(DFPattern parent, DFPattern path, DFPattern child); + TVM_DLL DominatorPattern(DFPattern parent, DFPattern path, DFPattern child); TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode); }; diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index cac914220d9b..b6080febec92 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -25,6 +25,7 @@ #define TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ #include + #include #include diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index bde445cf4cec..51b49eb87cfd 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -26,7 +26,9 @@ #include #include #include + #include + #include "indexed_graph.h" namespace tvm { diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 74dafc5ec647..eb328eb30793 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -21,11 +21,12 @@ * \file src/relay/ir/indexed_graph.cc * \brief Utilties for Creating Indexed Graphs. */ +#include "indexed_graph.h" + #include #include #include #include -#include "indexed_graph.h" namespace tvm { namespace relay { diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index c569c16b06a5..d2524340f971 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -25,6 +25,7 @@ #define TVM_RELAY_IR_INDEXED_GRAPH_H_ #include + #include #include #include From c08d91efddc7134995f48917e7b5f90147349c16 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 12 May 2020 11:54:11 -0700 Subject: [PATCH 41/46] fix text headers --- docs/langref/relay_pattern.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index d955366a865e..7f81b9b48299 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -110,32 +110,32 @@ Expression Pattern Match a literal expression. Wildcard -****************** +******** Match any expression. Type Pattern -****************** +************ Check that the expression matched by the nested pattern has a particular type. Attribute Pattern -****************** +***************** Check that the operator matched by the pattern has an attribute with a particular value. Input -****************** +***** Check that the expression is an input, i.e has no parents and is a variable. Alternate -****************** +********* Either match the first pattern or the second pattern. Domination -****************** +********** Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern. From 46f313a705917bfd43486eb9372813930f4d23a1 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 12 May 2020 11:54:24 -0700 Subject: [PATCH 42/46] Revert "move InferType Function" This reverts commit f41dfc11ff4fec345558d8c7b9cd8cdfc11a4cbb. --- include/tvm/relay/transform.h | 9 --------- src/relay/ir/dataflow_matcher.cc | 10 ++++++++++ src/relay/transforms/type_infer.cc | 5 ----- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 62799fbaf299..9a8ca8421997 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -375,15 +375,6 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); */ TVM_DLL Function InferType(const Function& f, const IRModule& mod, const GlobalVar& var); -/*! - * \brief Infer the type of an expression base on it's inputs. - * - * \param expr the Expr. - * - * \return A type checked Expr with its checked_type field populated. - */ -TVM_DLL Expr InferType(const Expr& expr); - /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. This * function is used as a helper function to rewrtie an expression in a pass. diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 51b49eb87cfd..eb6a9209f6a8 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -337,6 +337,16 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e return matches; } +Expr InferType(const Expr& expr) { + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + if (expr.as()) { + return mod->Lookup("main"); + } else { + return mod->Lookup("main").as()->body; + } +} + bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { auto expr_type = InferType(expr).as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index a7334eb09525..078248483587 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -742,11 +742,6 @@ Function InferType(const Function& func, const IRModule& mod, const GlobalVar& v return Downcast(func_ret); } -Expr InferType(const Expr& expr) { - auto mod = IRModule::FromExpr(expr); - return InferType(expr, mod); -} - namespace transform { Pass InferType() { From 50be5b3a97ab4da905906374c10e709410d49654 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 12 May 2020 14:51:44 -0700 Subject: [PATCH 43/46] add optional syntactic sugar --- python/tvm/relay/dataflow_pattern/__init__.py | 17 ++++++ tests/python/relay/test_dataflow_pattern.py | 59 +++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index cde229c6604a..675aaa9be647 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -148,6 +148,23 @@ def dominates(self, parent, path=None): path = wildcard() return DominatorPattern(parent, path, self) + def optional(self, option_constructor): + """ + Create a dominator for this partern + + Parameters + ---------- + option_constructor: function + A function that takes a single Pattern parameter and returns + a constructed pattern matching the option + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting Pattern + """ + return self | option_constructor(self) + def is_input(name: str = "") -> DFPattern: """ diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index e2c6d7d60ac3..a93a39be14d0 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -126,6 +126,65 @@ def test_no_match_call(): add_pattern = is_op('add')(wildcard(), wildcard()) assert not add_pattern.match(x - y) +def test_match_option(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + pattern = is_op("nn.relu")( + is_op("nn.conv2d")(wildcard(), wildcard() + ).optional(lambda x: is_op("nn.bias_add")(x, wildcard())) + ) + + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + assert pattern.match(relu) + + conv2d = relay.op.nn.conv2d(x, w) + bias_add = relay.op.nn.bias_add(conv2d, b) + relu = relay.op.nn.relu(bias_add) + assert pattern.match(relu) + + pattern = is_op("nn.conv2d")(wildcard(), wildcard()) + pattern = pattern.optional(is_op('nn.relu')).optional(is_op("tanh")) + + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + tanh = relay.op.tanh(conv2d) + tanh2 = relay.op.tanh(relu) + relu2 = relay.op.nn.relu(tanh) + assert pattern.match(conv2d) + assert pattern.match(relu) + assert pattern.match(tanh) + assert pattern.match(tanh2) + assert not pattern.match(relu2) + +def test_no_match_option(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + pattern = is_op("nn.relu")( + is_op("nn.conv2d")(wildcard(), wildcard() + ).optional(lambda x: is_op("nn.bias_add")(x, wildcard())) + ) + + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.tanh(conv2d) + assert not pattern.match(relu) + + conv2d = relay.op.nn.dense(x, w) + relu = relay.op.tanh(conv2d) + assert not pattern.match(relu) + + conv2d = relay.op.nn.dense(x, w) + bias_add = relay.op.nn.bias_add(conv2d, b) + relu = relay.op.nn.relu(bias_add) + assert not pattern.match(relu) + + conv2d = relay.op.nn.conv2d(x, w) + bias_add = conv2d + w + relu = relay.op.nn.relu(bias_add) + assert not pattern.match(relu) + def test_match_tuple(): x = relay.var('x') y = relay.var('y') From c22c51827f7e980c392c926c4ce7b74356738ac0 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 12 May 2020 16:03:35 -0700 Subject: [PATCH 44/46] fix a comment --- python/tvm/relay/dataflow_pattern/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 675aaa9be647..ca324bc444ec 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -130,7 +130,7 @@ def partition(self, expr: Expr) -> bool: def dominates(self, parent, path=None): """ - Create a dominator for this partern + Create a dominator for this pattern Parameters ---------- @@ -150,7 +150,7 @@ def dominates(self, parent, path=None): def optional(self, option_constructor): """ - Create a dominator for this partern + Create a optional user of this pattern Parameters ---------- From e83b86c53390129251e6e7f810b8f01f1280ad96 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 13 May 2020 13:09:11 -0700 Subject: [PATCH 45/46] fix comment typos --- src/relay/ir/indexed_graph.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index eb328eb30793..79ec57426d66 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -35,7 +35,7 @@ namespace relay { IndexedGraph CreateIndexedGraph(const Expr& expr) { using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ + /*! \brief Creator Creates an IndexedGraph and determintes Topological order */ class Creator : public MixedModeVisitor { public: IndexedGraph CreateGraph(const Expr& expr) { @@ -54,7 +54,7 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { IndexedGraph graph_; size_t index_ = 0; }; - /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does dominator tree * analysis. * * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined From 0878bb288cc26dc4b0da16018f16b8d21408036f Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 14 May 2020 14:19:05 -0700 Subject: [PATCH 46/46] respond to @masahi's comments --- include/tvm/relay/dataflow_matcher.h | 31 ++++++++++++++++++ include/tvm/relay/dataflow_pattern_functor.h | 4 +-- src/relay/ir/dataflow_matcher.cc | 34 ++++++++------------ 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 2c3601c368f8..58aa6400b650 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -61,6 +61,37 @@ class DFPatternCallback : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); }; +/*! + * \brief Determine if a pattern matches an expression + * + * \param pattern The pattern to match + * \param expr The expression to match + * + * \return Return true if the pattern and the expression match, return false otherwise. + */ +bool MatchPattern(DFPattern pattern, Expr expr); + +/*! + * \brief Rewrite an expression based on some number of DFPatternCallbacks + * + * \param callbacks An array of DFPatternCallback Nodes + * \param expr The expression to rewrite + * + * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the + * functions inside the callbacks + */ +Expr RewritePatterns(Array callbacks, Expr expr); + +/*! + * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls + * + * \param pattern The pattern to match + * \param expr The expression to patition + * + * \return Return the paritioned Expr. + */ +Expr PartitionPattern(DFPattern pattern, Expr expr); + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index b6080febec92..05c2147c2c49 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -18,8 +18,8 @@ */ /*! - * \file tvm/relay/dataflow_matcher.h - * \brief A pattern matcher for matching dataflow properties. + * \file tvm/relay/dataflow_pattern_functor.h + * \brief A set of passes for operating on pattern graphs. */ #ifndef TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ #define TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index eb6a9209f6a8..81fc4f03d886 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -64,7 +64,6 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectHash, ObjectEqual> memo_; std::vector matched_nodes_; IndexedGraph expr_graph_; - IndexedGraph pattern_graph_; bool memoize_ = true; }; @@ -298,7 +297,6 @@ bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Exp } bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { - pattern_graph_ = CreateIndexedGraph(GetRef(op)); if (VisitDFPattern(op->child, expr)) { bool matches_path = MatchesPath(op, expr); memoize_ = true; @@ -367,10 +365,11 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } -TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match") - .set_body_typed([](DFPattern pattern, Expr expr) { - return DFPatternMatcher(expr).Match(pattern, expr); - }); +bool MatchPattern(DFPattern pattern, Expr expr) { + return DFPatternMatcher(expr).Match(pattern, expr); +} + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); /* \brief PatternGrouper does pre-rewriting pattern matching and analysis * @@ -391,15 +390,12 @@ class PatternGrouper : protected MixedModeVisitor { Array args; }; - /* \brief Return the discovered groups */ - const std::vector& GetGroups() { return this->groups_; } - /* \brief Return the group assignments of expressions */ const std::unordered_map& GetGIDAssignments() { return gid_assignments_; } /* \brief Group expressions that match the pattern */ - void GroupMatches(const DFPattern& pattern, const Expr& pre) { + const std::vector& GroupMatches(const DFPattern& pattern, const Expr& pre) { groups_ = {Group()}; gid_assignments_.clear(); visit_counter_.clear(); @@ -409,6 +405,7 @@ class PatternGrouper : protected MixedModeVisitor { auto matcher = DFPatternMatcher(pre); matcher_ = &matcher; this->VisitExpr(pre); + return this->groups_; } protected: @@ -438,7 +435,7 @@ class PatternGrouper : protected MixedModeVisitor { /* \brief Create a group based on a matched expression */ void CreateGroup(const Expr& expr) { - var_number_ = 0; + int var_number = 0; auto node_map = matcher_->GetMemo(); @@ -468,12 +465,12 @@ class PatternGrouper : protected MixedModeVisitor { for (auto match : matches) { if (fuzzy_matches.count(match) == 0 && match.as() == nullptr && match.as() == nullptr && match.as() == nullptr) { - inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" + - std::to_string(var_number_), - NullValue()); + inputs[match] = Var( + "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), + NullValue()); group.args.push_back(match); params.push_back(inputs[match]); - var_number_++; + var_number++; } } } @@ -525,7 +522,6 @@ class PatternGrouper : protected MixedModeVisitor { DFPatternMatcher* matcher_ = nullptr; IndexedGraph pattern_graph_; int gid_ = 0; - int var_number_ = 0; int graph_number_ = 0; }; @@ -565,8 +561,7 @@ class PatternRewriter : protected MixedModeMutator { for (auto callback : callbacks) { callback_ = callback; auto grouper = PatternGrouper(); - grouper.GroupMatches(callback_->pattern_, post); - groups_ = grouper.GetGroups(); + groups_ = grouper.GroupMatches(callback_->pattern_, post); gid_assignments_ = grouper.GetGIDAssignments(); memo_.clear(); post = this->VisitExpr(post); @@ -619,8 +614,7 @@ class PatternPartitioner : protected MixedModeMutator { public: Expr Partition(const DFPattern& pattern, const Expr& pre) { auto grouper = PatternGrouper(); - grouper.GroupMatches(pattern, pre); - groups_ = grouper.GetGroups(); + groups_ = grouper.GroupMatches(pattern, pre); gid_assignments_ = grouper.GetGIDAssignments(); return this->VisitExpr(pre); }