From 8a26c94840d80324fcb81fc27ca13ee16e339f55 Mon Sep 17 00:00:00 2001 From: Mercy Date: Sun, 18 Nov 2018 19:29:02 -0800 Subject: [PATCH 01/10] [RELAY] Finish alter op pass --- include/tvm/relay/op_attr_types.h | 12 ++ python/tvm/__init__.py | 1 + python/tvm/attrs.py | 40 ++++ python/tvm/relay/base.py | 14 ++ python/tvm/relay/ir_pass.py | 19 ++ python/tvm/relay/op/__init__.py | 4 +- python/tvm/relay/op/op.py | 22 ++- python/tvm/relay/op/op_attrs.py | 9 + src/lang/attrs.cc | 6 + src/relay/op/nn/pad.cc | 174 +++++++++--------- src/relay/pass/alter_op_layout.cc | 65 +++++++ src/relay/pass/fold_scale_axis.cc | 4 +- .../python/relay/test_pass_alter_op_layout.py | 70 +++++++ 13 files changed, 348 insertions(+), 92 deletions(-) create mode 100644 python/tvm/attrs.py create mode 100644 python/tvm/relay/op/op_attrs.py create mode 100644 src/relay/pass/alter_op_layout.cc create mode 100644 tests/python/relay/test_pass_alter_op_layout.py diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 3d9fa56855c3..22cf892565c7 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -86,6 +86,18 @@ using FTVMSchedule = runtime::TypedPackedFunc< const Array& outs, const Target& target)>; +/*! + * \brief Alternate the layout of operators or replace the + * operator with other expressions. + * + * \param attrs The attribute of the node. + * \param inputs The arguments of this operator. + * \return new_expr The modified expression. + */ +using FTVMAlterOpLayout = runtime::TypedPackedFunc< + Expr(const Attrs& attrs, + const Array& args)>; + /*! * \brief Forward rewriting rule for a specific op. * diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index e202c5adb967..67dd54d1db4d 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -13,6 +13,7 @@ from . import schedule from . import module from . import node +from . import attrs from . import ir_builder from . import target from . import generic diff --git a/python/tvm/attrs.py b/python/tvm/attrs.py new file mode 100644 index 000000000000..f4132652dfd8 --- /dev/null +++ b/python/tvm/attrs.py @@ -0,0 +1,40 @@ +""" TVM Attribute module, which is mainly used for defining attributes of operators""" +from ._ffi.node import NodeBase, register_node as _register_tvm_node +from ._ffi.function import _init_api +from . import _api_internal + + +@_register_tvm_node +class Attrs(NodeBase): + """Attribute node, which is mainly use for defining attributes of relay operators. + + Used by python registration of compute and schedule function. + Attrs is passed as the first argument to schedule and compute function. + """ + def list_field_info(self): + """ Get fields information + + Returns + ------- + infos: list of AttrFieldInfo + List of field information + """ + return _api_internal._AttrsListFieldInfo(self) + + def keys(self): + """Get list of names in the attribute. 
+ + Returns + ------- + keys : list of str + List of keys + """ + fields = self.list_field_info() + for field in fields: + yield field.name + + def __getitem__(self, item): + return self.__getattr__(item) + + +_init_api("tvm.attrs") diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 83aa4ec2cdd0..f1105fe4f0d9 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -21,6 +21,20 @@ def register_relay_node(type_key=None): return _register_tvm_node(type_key) +def register_relay_attr_node(type_key=None): + """register relay attribute node + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if not isinstance(type_key, str): + return _register_tvm_node( + "relay.attrs." + type_key.__name__)(type_key) + return _register_tvm_node(type_key) + + class RelayNode(NodeBase): """Base class of all relay node.""" def astext(self, show_meta_data=True, annotate=None): diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 6297e366070f..f3bb586fc211 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -321,3 +321,22 @@ def combine_parallel_conv2d(expr): Transformed expression """ return _ir_pass.CombineParallelConv2D(expr) + + +def alter_op_layout(expr): + """Alternate the layouts of operators or replace primitive operators with + other expressions. + This pass can be used for computing convolution in custom layouts or + other general weight pre-transformation. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + transformed_expr : tvm.relay.Expr + Transformed expression with alternated layout. + """ + return _ir_pass.AlterOpLayout(expr) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index b32db4c23f3e..4a6dfd9f7335 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,7 +1,8 @@ #pylint: disable=wildcard-import, redefined-builtin """Relay core operators.""" # operator defs -from .op import get, register, register_schedule, register_compute, Op +from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \ + Op # Operators from .reduce import * @@ -10,6 +11,7 @@ from . import nn from . import image from . import vision +from . import op_attrs # operator registry from . import _tensor diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index c777a82462c8..dd3af9c44e42 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -107,7 +107,7 @@ def register_schedule(op_name, schedule=None, level=10): op_name : str The name of the op. - schedule : function + schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule The schedule function. level : int @@ -124,7 +124,8 @@ def register_compute(op_name, compute=None, level=10): op_name : str The name of the op. - compute : function + compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target) + -> List[Tensor] The compute function. 
level : int @@ -133,6 +134,23 @@ def register_compute(op_name, compute=None, level=10): return register(op_name, "FTVMCompute", compute, level) +def register_alter_op_layout(op_name, alter_layout=None, level=10): + """Register alter op layout function for an op + + Parameters + ---------- + op_name : str + The name of the operator + + alter_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr + The function for changing the layout or replacing the operator + + level : int + The priority level + """ + return register(op_name, "FTVMAlterOpLayout", alter_layout, level) + + def register_pattern(op_name, pattern, level=10): """Register operator pattern for an op. diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py new file mode 100644 index 000000000000..af50905295a3 --- /dev/null +++ b/python/tvm/relay/op/op_attrs.py @@ -0,0 +1,9 @@ +"""The attributes node used for Relay operators""" + +from ...attrs import Attrs +from ..base import register_relay_attr_node + +@register_relay_attr_node +class Conv2DAttrs(Attrs): + """Attribute of a Convolution Operator""" + pass diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 3b273f4939ef..1daf1e792553 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -3,6 +3,7 @@ * \file attrs.cc */ #include +#include #include "attr_functor.h" namespace tvm { @@ -321,4 +322,9 @@ bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const { return equal(this->dict, static_cast(other)->dict); } +TVM_REGISTER_API("_AttrsListFieldInfo") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0].operator Attrs()->ListFieldInfo(); +}); + } // namespace tvm diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 6e02d74e6ea8..8f3e758726fd 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -1,87 +1,87 @@ -/*! 
- * Copyright (c) 2018 by Contributors - * \file pad.cc - * \brief Implementation of operator pad - */ -#include -#include -#include -#include -#include "../layout.h" - -namespace tvm { -namespace relay { - -TVM_REGISTER_NODE_TYPE(PadAttrs); - -bool PadRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - - const PadAttrs* param = attrs.as(); - CHECK(param != nullptr); - - // check that pad widths match lengths - CHECK(data->shape.size() == param->pad_width.size()) - << "There should be as many pad width pairs as shape dimensions " - << "but the shape has " << data->shape.size() << " dimensions " - << "and there are " << param->pad_width.size() << " pad width pairs."; - - // each pad width element should be a pair of positive integers - std::vector oshape; - for (size_t i = 0; i < param->pad_width.size(); i++) { - CHECK(param->pad_width[i].size() == 2) - << "Each pad width element should be a pair but at index " << i - << " there are " << param->pad_width[i].size() << " elements."; - - auto width1 = as_const_int(param->pad_width[i][0]); - auto width2 = as_const_int(param->pad_width[i][1]); - CHECK(width1 != nullptr); - CHECK(width2 != nullptr); - - CHECK(*width1 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width1 << "."; - CHECK(*width2 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width2 << "."; - - auto padding = make_const(data->shape[i].type(), *width1 + *width2); - oshape.push_back(data->shape[i] + padding); - } - - reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), - data->dtype)); - return true; -} - -// Handler to create a call to the padding op used by front-end FFI -Expr MakePad(Expr data, Array > pad_width, double pad_value) { - auto attrs = make_node(); - attrs->pad_value = pad_value; - attrs->pad_width = std::move(pad_width); - static const Op& op = Op::Get("nn.pad"); - return CallNode::make(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_API("relay.op.nn._make.pad") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakePad, args, rv); - }); - -RELAY_REGISTER_OP("nn.pad") -.describe(R"code(Pad for n-D tensor. - -)code" TVM_ADD_FILELINE) -.set_attrs_type_key("relay.attrs.PadAttrs") -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("Pad", PadRel); - -} // namespace relay -} // namespace tvm +/*! 
+ * Copyright (c) 2018 by Contributors + * \file pad.cc + * \brief Implementation of operator pad + */ +#include +#include +#include +#include +#include "../layout.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(PadAttrs); + +bool PadRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const PadAttrs* param = attrs.as(); + CHECK(param != nullptr); + + // check that pad widths match lengths + CHECK(data->shape.size() == param->pad_width.size()) + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; + + // each pad width element should be a pair of positive integers + std::vector oshape; + for (size_t i = 0; i < param->pad_width.size(); i++) { + CHECK(param->pad_width[i].size() == 2) + << "Each pad width element should be a pair but at index " << i + << " there are " << param->pad_width[i].size() << " elements."; + + auto width1 = as_const_int(param->pad_width[i][0]); + auto width2 = as_const_int(param->pad_width[i][1]); + CHECK(width1 != nullptr); + CHECK(width2 != nullptr); + + CHECK(*width1 >= 0) + << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + CHECK(*width2 >= 0) + << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; + + auto padding = make_const(data->shape[i].type(), *width1 + *width2); + oshape.push_back(data->shape[i] + padding); + } + + reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), + data->dtype)); + return true; +} + +// Handler to create a call to the padding op used by front-end FFI +Expr MakePad(Expr data, Array > pad_width, double pad_value) { + auto attrs = make_node(); + attrs->pad_value = pad_value; + attrs->pad_width = std::move(pad_width); + static const Op& op = Op::Get("nn.pad"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.pad") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakePad, args, rv); + }); + +RELAY_REGISTER_OP("nn.pad") +.describe(R"code(Pad for n-D tensor. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.PadAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(2) +.add_type_rel("Pad", PadRel); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc new file mode 100644 index 000000000000..429e44df4f9d --- /dev/null +++ b/src/relay/pass/alter_op_layout.cc @@ -0,0 +1,65 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alter_op_layout.cc + * \brief Alternate the layouts of operators or replace primitive operators with + other expressions. This pass can be used for computing convolution in + custom layouts or other general weight pre-transformation. 
+ */ +#include +#include +#include +#include "../op/layout.h" + +namespace tvm { +namespace relay { + +using LayoutMap = std::unordered_map; + +class LayoutCorrector: public ExprMutator { + public: + LayoutCorrector() { + + } + + Expr Correct(Expr expr) { + return expr; + } +}; + +class LayoutAlternator: public ExprMutator { + public: + Expr VisitExpr_(const CallNode* n) { + static auto falter_layout = + Op::GetAttr("FTVMAlterOpLayout"); + + Expr new_e = ExprMutator::VisitExpr_(n); + const auto* new_n = new_e.as(); + + if(!new_n->op.as()) + return new_e; + + Op op = Downcast(new_n->op); + + if (falter_layout.count(op)) { + Expr ret = falter_layout[op](new_n->attrs, new_n->args); + if (ret.defined()) + return ret; + } + return new_e; + } +}; + +TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") +.set_body([](TVMArgs args, TVMRetValue* ret) { + Expr expr = args[0]; + LayoutCorrector corrector; + + expr = corrector.Correct(expr); + expr = LayoutAlternator().Mutate(expr); + expr = corrector.Correct(expr); + + *ret = expr; +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index bcb91e7e5737..c56ee98a3969 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -29,11 +29,11 @@ using runtime::TypedPackedFunc; // FoldScaleAxis algorithm: // // The general idea is to transform Expr to tuple of -// (value, axes, scale), where the final result satiesfies: +// (value, axes, scale), where the final result satisfies: // // result = value // for i, k in enumerate(axes): -// k-ith dimension of result *= i-th dimension of scale +// k-th dimension of result *= i-th dimension of scale // // Then we can propagate this signal along and fold the scale if necessary. 
// However, it is possible that certain scale may never be consumed diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py new file mode 100644 index 000000000000..be42b4c9d732 --- /dev/null +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -0,0 +1,70 @@ +"""Test alter op layout pass""" + +from tvm import relay +from tvm.relay.op import register_alter_op_layout +from tvm.relay.ir_pass import alter_op_layout, alpha_equal, infer_type + +def test_alter_op(): + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + @register_alter_op_layout("nn.conv2d", level=100) + def alter_conv2d(attrs, inputs): + data, weight = inputs + weight = relay.multiply(weight, relay.const(2.0)) + return relay.nn.conv2d(data, weight, **attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0)), + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + a = before() + a = infer_type(a) + a = alter_op_layout(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + + +def test_alter_return_none(): + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + y = relay.nn.global_max_pool2d(x) + y = relay.Function([x], y) + return y + + called = [False] + + @register_alter_op_layout("nn.global_max_pool2d", level=101) + def alter_conv2d(attrs, inputs): + called[0] = True + return None + + a = before() + a = alter_op_layout(a) + + b = before() + assert(alpha_equal(a, b)) + assert(called[0]) + + +if __name__ == "__main__": + test_alter_op() + test_alter_return_none() From f55b89ef9da075ba3e2e8c63981b14b13fee824e Mon Sep 17 00:00:00 2001 From: Mercy Date: Mon, 26 Nov 2018 09:23:31 -0800 Subject: [PATCH 02/10] [RELAY] AlterOpLayout Pass --- include/tvm/relay/attrs/nn.h | 5 + include/tvm/relay/attrs/transform.h | 13 + include/tvm/relay/expr.h | 2 +- include/tvm/relay/op_attr_types.h | 13 +- include/tvm/relay/pass.h | 13 + python/tvm/attrs.py | 4 +- python/tvm/relay/build_module.py | 8 + python/tvm/relay/ir_pass.py | 17 ++ python/tvm/relay/op/_tensor.py | 9 - python/tvm/relay/op/_transform.py | 16 +- python/tvm/relay/op/op_attrs.py | 5 + python/tvm/relay/op/transform.py | 22 ++ src/relay/op/layout.h | 11 + src/relay/op/nn/convolution.cc | 26 +- src/relay/op/nn/nn.cc | 19 +- src/relay/op/nn/pad.cc | 1 + src/relay/op/nn/pooling.cc | 40 ++- src/relay/op/op_common.h | 57 ++-- src/relay/op/tensor/binary.cc | 40 ++- src/relay/op/tensor/transform.cc | 135 ++++++++- src/relay/op/tensor/unary.cc | 83 +++--- src/relay/pass/alter_op_layout.cc | 268 ++++++++++++++++-- src/relay/pass/alter_op_layout.h | 67 +++++ src/relay/pass/forward_rewrite.cc | 39 ++- src/relay/pass/pattern_util.h | 2 +- src/relay/pass/simplify_bias_add.cc | 46 +++ .../python/relay/test_pass_alter_op_layout.py | 199 ++++++++++++- topi/include/topi/nn.h | 1 + 28 files changed, 1000 insertions(+), 161 deletions(-) create mode 100644 src/relay/pass/alter_op_layout.h create mode 100644 src/relay/pass/simplify_bias_add.cc diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 817ee04bd844..724749368aa9 100644 --- a/include/tvm/relay/attrs/nn.h +++ 
b/include/tvm/relay/attrs/nn.h @@ -105,6 +105,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { int groups; std::string data_layout; std::string weight_layout; + std::string out_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") { @@ -139,6 +140,10 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "dimensions respectively."); + TVM_ATTR_FIELD(out_layout).set_default("") + .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 3e56106df0c2..7e614a8cafd4 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -164,6 +164,19 @@ struct ClipAttrs : public tvm::AttrsNode { } }; + +struct LayoutTransformAttrs : public tvm::AttrsNode { + std::string src_layout; + std::string dst_layout; + + TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") { + TVM_ATTR_FIELD(src_layout) + .describe("The source layout of the tensor. (e.g. NCHW)"); + TVM_ATTR_FIELD(dst_layout) + .describe("The destination layout of the tensor. (e.g. NCHW16c)"); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 469b73a1df10..37c91ffe4ed2 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -459,7 +459,7 @@ inline const TTypeNode* ExprNode::type_as() const { static_assert(std::is_base_of::value, "TType must be a special case of type"); CHECK(checked_type_.defined()) - << "Type inference for this Expr has not completed"; + << "Type inference for this Expr has not completed. Try to call infer_type pass."; const TTypeNode* node = checked_type_.as(); CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 22cf892565c7..1f37e9947bb8 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -88,15 +88,18 @@ using FTVMSchedule = runtime::TypedPackedFunc< /*! * \brief Alternate the layout of operators or replace the - * operator with other expressions. - * - * \param attrs The attribute of the node. - * \param inputs The arguments of this operator. + * operator with other expressions. This function will be invoked + * in AlterOpLayout pass. + * \param attrs The attribute of the original node. + * \param inputs The input symbols of the original node. + * \param tinfos An array of placeholders, use for getting the inferred shape + * and dtype of the inputs. * \return new_expr The modified expression. */ using FTVMAlterOpLayout = runtime::TypedPackedFunc< Expr(const Attrs& attrs, - const Array& args)>; + const Array& args, + const Array& tinfos)>; /*! * \brief Forward rewriting rule for a specific op. 
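The updated FTVMAlterOpLayout hook above now also receives `tinfos`, placeholders carrying the inferred shape and dtype of the inputs. A minimal Python-side sketch of a registration against this three-argument signature; the NCHW16c/OIHW16i16o layouts and the priority level are illustrative assumptions, not part of the patch:

from tvm import relay
from tvm.relay.op import register_alter_op_layout

@register_alter_op_layout("nn.conv2d", level=50)
def alter_conv2d(attrs, inputs, tinfos):
    data, weight = inputs
    # tinfos[0] / tinfos[1] are placeholders with the inferred shapes of data / weight
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs["data_layout"] = "NCHW16c"
    new_attrs["weight_layout"] = "OIHW16i16o"
    # the pass is expected to insert layout_transform nodes around the returned call as needed
    return relay.nn.conv2d(data, weight, **new_attrs)
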
diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 58e160eb4ac9..298d1f77649f 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -8,6 +8,7 @@ #include #include +#include #include namespace tvm { @@ -173,6 +174,18 @@ Expr ForwardRewrite(const Expr& expr, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * \param expr The expression. + * \param rewrite_func The rewrite func that will apply to all operators. + * \param fcontext Additional callback to provide context argument for each call node. + * \return The rewritten expression. + */ +Expr ForwardRewrite(const Expr& expr, + const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr); + + /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. diff --git a/python/tvm/attrs.py b/python/tvm/attrs.py index f4132652dfd8..529dbcc14c13 100644 --- a/python/tvm/attrs.py +++ b/python/tvm/attrs.py @@ -8,8 +8,8 @@ class Attrs(NodeBase): """Attribute node, which is mainly use for defining attributes of relay operators. - Used by python registration of compute and schedule function. - Attrs is passed as the first argument to schedule and compute function. + Used by function registered in python side, such as compute, schedule and alter_layout. + Attrs is passed as the first argument to these functions. """ def list_field_info(self): """ Get fields information diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 863ca063137f..7b20f59f091b 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -17,6 +17,7 @@ "FoldConstant": 2, "CombineParallelConv2D": 3, "FoldScaleAxis": 3, + "AlterOpLayout": 3, } class BuildConfig(object): @@ -157,6 +158,13 @@ def optimize(func, params=None): if cfg.pass_enabled("FoldConstant"): func = ir_pass.fold_constant(func) + + if cfg.pass_enabled("AlterOpLayout"): + func = ir_pass.infer_type(func) + func = ir_pass.simplify_bias_add(func) + func = ir_pass.infer_type(func) + func = ir_pass.alter_op_layout(func) + return func diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index f3bb586fc211..f58f9ba68370 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -191,6 +191,23 @@ def simplify_inference(expr): return _ir_pass.simplify_inference(expr) +def simplify_bias_add(expr): + """ Simplify the bias_add to expand_dims and broadcast_add. + This can simplify latter layout related passes (e.g. alter_op_layout) + + Parameters + ---------- + e: tvm.relay.Expr + The input Expression + + Returns + ------- + result: tvm.relay.Expr + An expression without bias_add + """ + return _ir_pass.simplify_bias_add(expr) + + def dead_code_elimination(expr): """ Remove expressions which does not effect the program result (dead code). 
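A minimal sketch of the rewrite that simplify_bias_add performs, assuming NCHW data and the default bias axis of 1 (shapes below are illustrative). In the build pipeline, this pass and AlterOpLayout are both enabled at opt_level 3 and above:

from tvm import relay
from tvm.relay import ir_pass

x = relay.var("x", shape=(1, 64, 56, 56))
b = relay.var("b", shape=(64,))
f = relay.Function([x, b], relay.nn.bias_add(x, b))

f = ir_pass.infer_type(f)          # the rewrite needs type information
f = ir_pass.simplify_bias_add(f)
# the body is now roughly: add(x, expand_dims(b, axis=1, num_newaxis=2))
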
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 75ea3da8af80..d1035ee047e5 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -80,12 +80,3 @@ def clip_compute(attrs, inputs, output_type, target): return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)] register_schedule("clip", schedule_elemwise) -register_pattern("clip", OpPattern.ELEMWISE) - -# concatenate -@register_compute("concatenate") -def concatenate_compute(attrs, inputs, output_type, target): - return [topi.concatenate(inputs, axis=attrs.axis)] - -register_schedule("concatenate", schedule_injective) -register_pattern("concatenate", OpPattern.INJECTIVE) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 3093032f9e40..7c336221b31c 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1,8 +1,10 @@ """Backend compiler related feature registration""" # pylint: disable=invalid-name from __future__ import absolute_import +import topi from . import op as _reg from ._reduce import _schedule_reduce +from .op import schedule_injective, OpPattern schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective @@ -15,10 +17,22 @@ _reg.register_schedule("reshape_like", schedule_injective) _reg.register_schedule("full", schedule_injective) _reg.register_schedule("full_like", schedule_injective) -_reg.register_schedule("cast", schedule_broadcast) +_reg.register_schedule("cast", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("split", schedule_injective) _reg.register_schedule("take", schedule_injective) _reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("where", schedule_broadcast) + +# layout_transform +_reg.register_schedule("layout_transform", schedule_injective) +_reg.register_pattern("layout_transform", OpPattern.INJECTIVE) + +# concatenate +@_reg.register_compute("concatenate") +def concatenate_compute(attrs, inputs, output_type, target): + return [topi.concatenate(inputs, axis=attrs.axis)] + +_reg.register_schedule("concatenate", schedule_injective) +_reg.register_pattern("concatenate", OpPattern.INJECTIVE) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index af50905295a3..682d56fb9efc 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -7,3 +7,8 @@ class Conv2DAttrs(Attrs): """Attribute of a Convolution Operator""" pass + +@register_relay_attr_node +class GlobalPool2DAttrs(Attrs): + """Attribute of a Global 2D Pooling Operator""" + pass diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c5fedab054d2..17caad4bb304 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -387,3 +387,25 @@ def slice_like(data, shape_like, axes=None): The computed result. """ return _make.slice_like(data, shape_like, axes) + + +def layout_transform(data, src_layout, dst_layout): + """Transform the layout of a tensor + + Parameters + ---------- + data : relay.Expr + The source tensor to be transformed + + src_layout: str + The source layout. (e.g NCHW) + + dst_layout: str + The destination layout. (e.g. NCHW16c) + + Returns + ------- + ret : relay.Expr + The transformed tensor. 
+ """ + return _make.layout_transform(data, src_layout, dst_layout) diff --git a/src/relay/op/layout.h b/src/relay/op/layout.h index 97160f3cbb9e..85c39ec9b24f 100644 --- a/src/relay/op/layout.h +++ b/src/relay/op/layout.h @@ -327,6 +327,17 @@ class Layout : public NodeRef { return operator->()->name == rhs->name; } + /*! + * \brief allow output string of layout to ostream + * \param os the output stream + * \param l the layout + * \return the ostream + */ + friend std::ostream& operator<<(std::ostream& os, const Layout& l) { + os << l.name(); + return os; + } + using ContainerType = LayoutNode; private: diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index cb648166f7bb..b937a83c2f3a 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -7,11 +7,13 @@ #include #include +#include "../../pass/alter_op_layout.h" #include "../layout.h" namespace tvm { namespace relay { +// relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); bool Conv2DRel(const Array& types, @@ -101,6 +103,15 @@ bool Conv2DRel(const Array& types, return true; } +template +Array > Conv2DInferCorrectLayout(const Attrs& attrs, + const Array& in_layouts) { + const T* params = attrs.as(); + Layout out_layout(params->out_layout); + + return Array >{{params->data_layout, params->weight_layout}, + {out_layout.defined() ? out_layout : params->data_layout}}; +} // Positional relay function to create conv2d operator // used by frontend FFI. @@ -156,10 +167,11 @@ with the layer input to produce a tensor of outputs. .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) -.add_type_rel("Conv2D", Conv2DRel); +.add_type_rel("Conv2D", Conv2DRel) +.set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); -// Conv2DTranspose +// relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); bool Conv2DTransposeRel(const Array& types, @@ -185,6 +197,12 @@ bool Conv2DTransposeRel(const Array& types, << "Conv only support kernel layouts that are convertible from OIHW." << " But got "<< kernel_layout; + Layout out_layout(param->out_layout); + if (!out_layout.defined()) out_layout = in_layout; + CHECK(out_layout.Convertible(kNCHW)) + << "Conv only support output layouts that are convertible from NCHW." 
+ << " But got " << out_layout; + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW); @@ -241,7 +259,7 @@ bool Conv2DTransposeRel(const Array& types, if (out_dtype.bits() == 0) { out_dtype = data->dtype; } - oshape = ConvertLayout(oshape, kNCHW, in_layout); + oshape = ConvertLayout(oshape, kNCHW, out_layout); reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); return true; } @@ -307,6 +325,8 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) +.set_attr("FInferCorrectLayout", + Conv2DInferCorrectLayout) .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); } // namespace relay diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d3b454f35ede..7ed43d0df019 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -12,12 +12,14 @@ #include #include #include "../type_relations.h" +#include "../../pass/alter_op_layout.h" #include "../op_common.h" #include "../layout.h" namespace tvm { namespace relay { +// relay.nn.bias_add TVM_REGISTER_NODE_TYPE(BiasAddAttrs); bool BiasAddRel(const Array& types, @@ -74,6 +76,7 @@ RELAY_REGISTER_OP("nn.bias_add") .add_type_rel("BiasAdd", BiasAddRel); +// relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); @@ -143,6 +146,8 @@ RELAY_REGISTER_OP("nn.dense") .set_support_level(1) .add_type_rel("Dense", DenseRel); +// relay.leaky_relu +TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); // Positional relay function to create leaky relu operator used by frontend FFI. Expr MakeLeakyRelu(Expr data, @@ -171,6 +176,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") .add_argument("data", "Tensor", "Input data.") .set_support_level(3) .add_type_rel("Identity", IdentityRel) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr( "FTVMCompute", [](const Attrs& attrs, const Array& inputs, @@ -181,6 +187,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") }); +// relay.prelu TVM_REGISTER_NODE_TYPE(PReluAttrs); bool PReluRel(const Array& types, @@ -235,6 +242,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. .add_argument("alpha", "Tensor", "Input channelwise alpha.") .set_support_level(3) .add_type_rel("PRelu", PReluRel) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr( "FTVMCompute", [](const Attrs& attrs, const Array& inputs, @@ -245,6 +253,9 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. 
}); +// relay.softmax +TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); + TVM_REGISTER_API("relay.op.nn._make.softmax") .set_body([](const TVMArgs& args, TVMRetValue* rv) { auto make_func = [](Expr data, int axis) { @@ -282,6 +293,7 @@ RELAY_REGISTER_OP("nn.softmax") }); +// relay.nn.log_softmax TVM_REGISTER_API("relay.op.nn._make.log_softmax") .set_body([](const TVMArgs& args, TVMRetValue* rv) { auto make_func = [](Expr data, int axis) { @@ -321,8 +333,7 @@ RELAY_REGISTER_OP("nn.log_softmax") }); - -// BatchFlatten +// relay.nn.batch_flatten bool BatchFlattenRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -410,6 +421,7 @@ RELAY_REGISTER_OP("nn.relu") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(1) .add_type_rel("Identity", IdentityRel) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -460,6 +472,7 @@ centered at that value (zero padding is added where necessary). .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .add_type_rel("Identity", IdentityRel); @@ -495,6 +508,7 @@ Normalizes along dimension axis using an L2 norm .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .add_type_rel("Identity", IdentityRel); // Dropout @@ -538,6 +552,7 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input .set_num_inputs(1) .add_argument("data", "Tensor", "Input to which dropout will be applied.") .set_support_level(1) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .add_type_rel("Dropout", DropoutRel); // batch_norm diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 8f3e758726fd..5403d0620e50 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -12,6 +12,7 @@ namespace tvm { namespace relay { +// relay.nn.pad TVM_REGISTER_NODE_TYPE(PadAttrs); bool PadRel(const Array& types, diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 0af0bbf63633..a68b984f3081 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -9,13 +9,40 @@ #include #include #include "../layout.h" +#include "../../pass/alter_op_layout.h" namespace tvm { namespace relay { +// relay.nn.max_pool2d & relay.nn.avg_pool2d TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); +template +Array > Pool2DInferCorrectLayout( + const Attrs& attrs, + const Array& in_layouts) { + CHECK_EQ(in_layouts.size(), 1); + + // NOTE: Discard "const" qualifier here. 
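+  // This function returns {required input layouts}, {output layouts};
+  // the layout field in attrs is updated in place below so that the
+  // layout finally chosen here is also the one later stages see.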
+ T *params = const_cast(attrs.as()); + Layout input = in_layouts[0]; + const Layout raw_layout(params->layout); + if (input.defined()) { + CHECK(input.Convertible(raw_layout)); + if (input.Indexof('W') != raw_layout.Indexof('W') || + input.Indexof('H') != raw_layout.Indexof('H') || + input.Contains('w') || input.Contains('h')) { + // if the new layout changes width or height dimension, + // fallback to old layout; + input = raw_layout; + } + params->layout = input.name(); // modify self to follow the input layout + } + + return Array >{{params->layout}, {params->layout}}; +} + template bool Pool2DRel(const Array& types, int num_inputs, @@ -163,6 +190,7 @@ RELAY_REGISTER_OP("nn.max_pool2d") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("MaxPool2D", Pool2DRel) +.set_attr("FInferCorrectLayout", Pool2DInferCorrectLayout) .set_attr("FTVMCompute", Pool2DCompute); @@ -219,9 +247,10 @@ Average pooling operation for one dimensional data. .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("AvgPool2D", Pool2DRel) +.set_attr("FInferCorrectLayout", Pool2DInferCorrectLayout) .set_attr("FTVMCompute", Pool2DCompute); -// Global Pool +// relay.nn.global_pool_2d & relay.nn.max_pool_2d TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs); bool GlobalPool2DRel(const Array& types, @@ -247,8 +276,9 @@ bool GlobalPool2DRel(const Array& types, const auto hidx = layout.Indexof('H'); const auto widx = layout.Indexof('W'); - std::vector oshape({dshape[0], dshape[1], dshape[2], dshape[3]}); - oshape[hidx] = oshape[widx] = 1; + Array oshape(dshape); + oshape.Set(hidx, 1); + oshape.Set(widx, 1); // assign output type reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); @@ -307,6 +337,8 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) +.set_attr("FInferCorrectLayout", + Pool2DInferCorrectLayout) .set_attr("FTVMCompute", GlobalPool2DCompute); // GlobalMaxPool @@ -338,6 +370,8 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) +.set_attr("FInferCorrectLayout", + Pool2DInferCorrectLayout) .set_attr("FTVMCompute", GlobalPool2DCompute); } // namespace relay diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 5bb2f24cae81..36cd04931903 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -11,6 +11,7 @@ #include #include #include +#include "../pass/alter_op_layout.h" namespace tvm { namespace relay { @@ -32,21 +33,24 @@ inline std::vector AsVector(const Array &array) { * We make the decision to always only expose positional argument. * We will do rewrapping in the frontend to support language * sugars such as keyword arguments and default value. - * - * \param Prefix the prefix of the registry, for example, "relay.op._make.". - * + * \param OpName the name of registry. */ -#define RELAY_REGISTER_UNARY_OP(Prefix, OpName) \ - TVM_REGISTER_API(Prefix OpName) \ - .set_body_typed([](Expr data) { \ - static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {data}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") \ - .set_attr("TOpPattern", kElemWise) +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." 
OpName) \ + .set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") \ + .add_type_rel("Identity", IdentityRel) \ + .set_attr("TOpPattern", kElemWise) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + ElemwiseArbitraryLayout) \ + /*! Quick helper macro * - Expose a positional make function to construct the node. @@ -56,12 +60,10 @@ inline std::vector AsVector(const Array &array) { * We will do rewrapping in the frontend to support language * sugars such as keyword arguments and default value. * - * \param Prefix the prefix of the registry, for example, "relay.op._make.". - * * \param OpName the name of registry. */ -#define RELAY_REGISTER_BINARY_OP(Prefix, OpName) \ - TVM_REGISTER_API(Prefix OpName) \ +#define RELAY_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs) { \ static const Op& op = Op::Get(OpName); \ return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ @@ -72,7 +74,26 @@ inline std::vector AsVector(const Array &array) { .add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_type_rel("Broadcast", BroadcastRel) \ .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + BinaryBroadcastLayout) + +// Comparisons +#define RELAY_REGISTER_CMP_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." OpName) \ + .set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("BroadcastComp", BroadcastCompRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + BinaryBroadcastLayout) } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc index 3f28bd52cd4b..da9b1af87578 100644 --- a/src/relay/op/tensor/binary.cc +++ b/src/relay/op/tensor/binary.cc @@ -23,71 +23,65 @@ namespace relay { // Addition -RELAY_REGISTER_BINARY_OP("relay.op._make.", "add") +RELAY_REGISTER_BINARY_OP("add") .describe("Elementwise add with with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add)); // Subtraction -RELAY_REGISTER_BINARY_OP("relay.op._make.", "subtract") +RELAY_REGISTER_BINARY_OP("subtract") .describe("Elementwise substract with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract)); // Right shift -RELAY_REGISTER_BINARY_OP("relay.op._make.", "right_shift") +RELAY_REGISTER_BINARY_OP("right_shift") .describe("Elementwise right shift with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "left_shift") + +RELAY_REGISTER_BINARY_OP("left_shift") .describe("Elementwise left shift with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "maximum") + +RELAY_REGISTER_BINARY_OP("maximum") .describe("Elementwise maximum of two tensors with broadcasting") 
.set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "minimum") + +RELAY_REGISTER_BINARY_OP("minimum") .describe("Elementwise minimum of two tensors with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "divide") + +RELAY_REGISTER_BINARY_OP("divide") .describe("Elementwise divide with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply") + +RELAY_REGISTER_BINARY_OP("multiply") .describe("Elementwise multiply with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "power") + +RELAY_REGISTER_BINARY_OP("power") .describe("Elementwise power with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod") + +RELAY_REGISTER_BINARY_OP("mod") .describe("Elementwise mod with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); -// Comparisons -#define RELAY_REGISTER_CMP_OP(OpName) \ - TVM_REGISTER_API("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("BroadcastComp", BroadcastCompRel) \ - .set_attr("TOpPattern", kBroadcast) RELAY_REGISTER_CMP_OP("equal") .describe("Elementwise equal compare with broadcasting") diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 4a052881d7bf..8b38fcb864fa 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -11,9 +11,11 @@ #include #include #include +#include #include #include "../op_common.h" #include "../../../arithmetic/compute_expr.h" +#include "../layout.h" namespace tvm { namespace relay { @@ -156,6 +158,7 @@ RELAY_REGISTER_OP("expand_dims") .set_attr("FTVMCompute", ExpandDimsCompute) .set_attr("TOpPattern", kBroadcast); +// relay.concatenate TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); bool ConcatenateRel(const Array& types, @@ -323,7 +326,6 @@ RELAY_REGISTER_OP("transpose") .set_attr("TOpPattern", kInjective); /* relay.reshape */ - TVM_REGISTER_NODE_TYPE(ReshapeAttrs); bool ReshapeRel(const Array& types, @@ -1252,7 +1254,7 @@ Examples:: .set_attr("TOpPattern", kInjective); -// Split +// relay.split TVM_REGISTER_NODE_TYPE(SplitAttrs); bool SplitRel(const Array& types, @@ -1367,6 +1369,7 @@ the entries indicate where along axis the array is split. .set_attr("TOpPattern", kInjective); +// relay.slice_like TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); /*! @@ -1513,5 +1516,133 @@ RELAY_REGISTER_OP("slice_like") .set_attr("FTVMCompute", SliceLikeCompute) .set_attr("TOpPattern", kInjective); + +// relay.layout_transform +std::pair RemoveLeadingReduandantDimensions( + const Layout &src_layout, const Layout &dst_layout, size_t keep_size) { + // For example, when broadcasting (1, 64, 16, 16) with (64, 1, 1), + // we can still apply rule `NCHW -> NCHW16c` to the right tensor, + // by deleting the leading redundant dimension "N" and apply normal "CHW -> CHW16c". 
+ CHECK_GE(src_layout.ndim(), keep_size) + << "Apply a " << src_layout.ndim() << "-dimensional rule " << src_layout + << " to " << keep_size << "-dimensional tensor"; + int n_remove = src_layout.ndim() - keep_size; + CHECK_GT(dst_layout.ndim(), n_remove); + for (int i = 0; i < n_remove; ++i) { + CHECK_EQ(src_layout[i], dst_layout[i]) + << "Can only delete the same dimension during layout transform"; + CHECK(Layout::IsSuperdim(src_layout[i])) + << "Can only delete a super dimension during layout transform"; + CHECK_EQ(src_layout.Subsizeof(src_layout[i]), -1) + << "Cannot delete a layout dimension with sub_dimension > 0"; + CHECK_EQ(dst_layout.Subsizeof(dst_layout[i]), -1) + << "Cannot delete a layout dimension with sub_dimension > 0"; + } + return std::make_pair(Layout(src_layout.name().substr(n_remove)), + Layout(dst_layout.name().substr(n_remove))); +} + +Array LayoutTransformCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const LayoutTransformAttrs *param = attrs.as(); + CHECK(param != nullptr); + + Layout src_layout(param->src_layout); + Layout dst_layout(param->dst_layout); + + if (src_layout.Equals(dst_layout)) { + return Array{ inputs[0] }; + } + + CHECK(src_layout.defined() && dst_layout.defined()) + << "cannot convert from/to undefined layout"; + CHECK(src_layout.Convertible(dst_layout)) + << "cannot convert from " << param->src_layout << " to " << param->dst_layout; + + std::tie(src_layout, dst_layout) = RemoveLeadingReduandantDimensions( + src_layout, dst_layout, inputs[0]->shape.size()); + + const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout); + return Array { + topi::layout_transform(inputs[0], out_shape, [&](const Array& dst_indices) { + std::vector dst_to_src_indices; + for (size_t i = 0; i < src_layout.ndim(); ++i) { + Layout::LayoutDim src_axis = src_layout[i]; + int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis)); + int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis)); + int32_t src_factor = static_cast(src_layout.Subsizeof(src_axis)); + int32_t dst_factor = static_cast(dst_layout.Subsizeof(src_axis)); + + tvm::Expr src_index(dst_indices[dst_major_pos]); + if (dst_minor_pos >= 0) { + CHECK_GT(dst_factor, 0); + src_index = src_index * dst_factor + dst_indices[dst_minor_pos]; + } + if (Layout::IsSuperdim(src_axis) && src_factor > 0) { + src_index = src_index / src_factor; + } else if (Layout::IsSubdim(src_axis) && src_factor > 0) { + src_index = src_index % src_factor; + } + dst_to_src_indices.push_back(src_index); + } + return Array(dst_to_src_indices); + }) + }; +} + +bool LayoutTransformRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + const auto* data = types[0].as(); + CHECK(data != nullptr); + const LayoutTransformAttrs* params = attrs.as(); + + Layout src_layout(params->src_layout); + Layout dst_layout(params->dst_layout); + + CHECK(src_layout.defined() && dst_layout.defined()) + << "cannot convert from/to undefined layout"; + CHECK(src_layout.Convertible(dst_layout)) + << "cannot convert from " << params->src_layout << " to " << params->dst_layout; + std::tie(src_layout, dst_layout) = RemoveLeadingReduandantDimensions( + src_layout, dst_layout, data->shape.size()); + + const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout); + reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype)); + return true; +} + +Expr MakeLayoutTransform(Expr data, + std::string src_layout, + 
std::string dst_layout) { + auto attrs = make_node(); + attrs->src_layout = std::move(src_layout); + attrs->dst_layout = std::move(dst_layout); + static const Op& op = Op::Get("layout_transform"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.layout_transform") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeLayoutTransform, args, rv); +}); + +RELAY_REGISTER_OP("layout_transform") +.describe(R"code(Transform the input data layout. + +For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes +the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.LayoutTransformAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.add_type_rel("layout_transform", LayoutTransformRel) +.set_support_level(5) +.set_attr("FTVMCompute", LayoutTransformCompute); + } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index fef0302a0507..b83fdacda1ee 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -22,7 +22,7 @@ namespace relay { } \ -RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") +RELAY_REGISTER_UNARY_OP("log") .describe(R"code(Returns the log input array, computed element-wise. .. math:: @@ -30,11 +30,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") +RELAY_REGISTER_UNARY_OP("exp") .describe(R"code(Returns the exp input array, computed element-wise. .. math:: @@ -42,36 +41,30 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); - -RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt") -.describe(R"code(Returns the sqrt input array, computed element-wise. +RELAY_REGISTER_UNARY_OP("sqrt") +.describe(R"code(Returns the rsqrt input array, computed element-wise. .. math:: sqrt(x) )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like") +RELAY_REGISTER_UNARY_OP("zeros_like") .describe(R"code(Returns an array of zeros, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.set_support_level(4); - -RELAY_REGISTER_UNARY_OP("relay.op._make.", "ones_like") +RELAY_REGISTER_UNARY_OP("ones_like") .describe(R"code(Returns an array of ones, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.set_support_level(4); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid") +RELAY_REGISTER_UNARY_OP("sigmoid") .describe(R"code(Returns the sigmoid input array, computed element-wise. .. math:: @@ -79,48 +72,47 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy") +RELAY_REGISTER_UNARY_OP("copy") .describe(R"code(Copy a tensor. 
)code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); // relay.clip TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_API("relay.op._make.clip") - .set_body_typed([](Expr a, double a_min, double a_max) { - auto attrs = make_node(); - attrs->a_min = a_min; - attrs->a_max = a_max; - static const Op& op = Op::Get("clip"); - return CallNode::make(op, {a}, Attrs(attrs), {}); - }); +.set_body_typed([](Expr a, double a_min, double a_max) { + auto attrs = make_node(); + attrs->a_min = a_min; + attrs->a_max = a_max; + static const Op& op = Op::Get("clip"); + return CallNode::make(op, {a}, Attrs(attrs), {}); +}); RELAY_REGISTER_OP("clip") - .describe(R"code(Clip tensor values. - This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. - )code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The input tensor.") - .set_support_level(3) - .add_type_rel("Clip", IdentityRel); - +.describe(R"code(Clip tensor values. +This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.add_type_rel("Identity", IdentityRel) +.set_attr("TOpPattern", kElemWise) +.set_attr("TOpIsStateful", false) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) +.set_support_level(3); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor") +RELAY_REGISTER_UNARY_OP("floor") .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil") +RELAY_REGISTER_UNARY_OP("ceil") .describe(R"code(Returns the ceil of input array, computed element-wise. .. math:: @@ -128,11 +120,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc") +RELAY_REGISTER_UNARY_OP("trunc") .describe(R"code(Returns the trunc of input array, computed element-wise. .. math:: @@ -140,11 +131,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); - -RELAY_REGISTER_UNARY_OP("relay.op._make.", "round") +RELAY_REGISTER_UNARY_OP("round") .describe(R"code(Returns the round of input array, computed element-wise. .. math:: @@ -152,11 +141,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs") +RELAY_REGISTER_UNARY_OP("abs") .describe(R"code(Returns the abs of input array, computed element-wise. .. 
math:: @@ -164,11 +152,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh") +RELAY_REGISTER_UNARY_OP("tanh") .describe(R"code(Returns the tanh of input array, computed element-wise. .. math:: @@ -176,11 +163,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative") +RELAY_REGISTER_UNARY_OP("negative") .describe(R"code(Returns the numeric negative of input array, computed element-wise. .. math:: @@ -188,7 +174,6 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); } // namespace relay diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 429e44df4f9d..369f8c622638 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -6,60 +6,268 @@ custom layouts or other general weight pre-transformation. */ #include -#include #include -#include "../op/layout.h" +#include +#include +#include +#include +#include +#include + +#include "alter_op_layout.h" namespace tvm { namespace relay { -using LayoutMap = std::unordered_map; +namespace alter_op_layout { + +// Make a transform CallNode +Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { + if (src_layout.Equals(dst_layout)) + return raw; + CHECK(src_layout.defined() && dst_layout.defined()) + << "Cannot insert layout transform because there are undefined layouts"; + CHECK(src_layout.Convertible(dst_layout)) + << "Cannot insert layout transform because there are inconvertible layouts: " + << src_layout << " v.s. 
" << dst_layout; + static auto &transform_op = Op::Get("layout_transform"); + NodePtr attrs = make_node(); + attrs->src_layout = src_layout.name(); + attrs->dst_layout = dst_layout.name(); + Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs}); + return transform; +} + +// Memorize layout transform so we can reuse internal transformed nodes +class TransformMemorizerNode : public Node { + public: + using TransformKey = std::tuple; + struct key_hash : public std::unary_function { + std::size_t operator()(const TransformKey& k) const { + return std::hash()(std::get<0>(k)) ^ + std::hash()(std::get<1>(k)) ^ + std::hash()(std::get<2>(k)); + } + }; + + std::unordered_map memo; + static constexpr const char *_type_key = "relay.alter_op_layout.TransformMemorizerNode"; + TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node); +}; -class LayoutCorrector: public ExprMutator { +class TransformMemorizer : public NodeRef { public: - LayoutCorrector() { + TransformMemorizer() {} + explicit TransformMemorizer(NodePtr n) : NodeRef(n) {} + TransformMemorizerNode* operator->() { + return static_cast(node_.get()); } - Expr Correct(Expr expr) { - return expr; + // Transform layout with memorizer + Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) { + std::tuple key = + std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name()); + auto& memo = operator->()->memo; + + auto iter = memo.find(key); + if (iter != memo.end()) { + return iter->second; + } else { + Expr transform = TransformLayout(raw, src_layout, dst_layout); + memo[key] = transform; + return transform; + } } + + using ContainerType = TransformMemorizerNode; }; -class LayoutAlternator: public ExprMutator { + +// TempExprNode during layout transform +// Instance of this expr will be Realized to normal expr ultimately +class LayoutAlternatedExprNode : public TempExprNode { public: - Expr VisitExpr_(const CallNode* n) { - static auto falter_layout = - Op::GetAttr("FTVMAlterOpLayout"); + Expr value; + Layout old_layout; + Layout new_layout; + TransformMemorizer memorizer; - Expr new_e = ExprMutator::VisitExpr_(n); - const auto* new_n = new_e.as(); + Expr Realize() const final { + // NOTE: use a copy to discard the "const" qualifier + TransformMemorizer tmp_memorizer = memorizer; + // fallback to old layout + return tmp_memorizer.Transform(value, new_layout, old_layout); + } - if(!new_n->op.as()) - return new_e; + void VisitAttrs(AttrVisitor *v) final { + v->Visit("value", &value); + v->Visit("old_layout", &old_layout); + v->Visit("new_layout", &new_layout); + } + + static constexpr const char *_type_key = "relay.alter_op_layout.LayoutAlternatedExprNode"; + TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode); +}; - Op op = Downcast(new_n->op); +RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr); + +// Call FInferCorrectLayout of an op. 
+// Return inferred_input_layout, inferred_output_layout, success +std::tuple, Array, bool> CallInfer( + const Call& call, + const Array& inputs) { + static auto finfer_layout = Op::GetAttr("FInferCorrectLayout"); + + Op op = Downcast(call->op); + if (finfer_layout.count(op)) { + Array > inferred_layouts; + inferred_layouts = finfer_layout[op](call->attrs, inputs); + CHECK_EQ(inferred_layouts.size(), 2) + << "FInferCorrectLayout should return an array with size of 2"; + return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true); + } else { + return std::make_tuple<>(Array(nullptr), Array(nullptr), false); + } +} - if (falter_layout.count(op)) { - Expr ret = falter_layout[op](new_n->attrs, new_n->args); - if (ret.defined()) - return ret; +// Call registered FTVMAlterOpLayout of an op +// Return altered expression +Call CallAlter(const Call& ref_call, + const std::vector& new_args) { + static auto falter_layout = Op::GetAttr("FTVMAlterOpLayout"); + Op op = Downcast(ref_call->op); + + Expr new_e; + bool modified = false; + if (falter_layout.count(op)) { + tvm::Array tinfos; + for (auto expr : ref_call->args) { + auto ttype = expr->type_as(); + tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype)); + } + Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos); + if (altered_value.defined()) { + new_e = altered_value; + modified = true; } - return new_e; } -}; + if (!modified) { + new_e = CallNode::make(ref_call->op, new_args, + ref_call->attrs, ref_call->type_args); + } -TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Expr expr = args[0]; - LayoutCorrector corrector; + const CallNode *new_call = new_e.as(); + CHECK(new_call) << "Can only replace the original operator with another call node"; + return GetRef(new_call); +} - expr = corrector.Correct(expr); - expr = LayoutAlternator().Mutate(expr); - expr = corrector.Correct(expr); +Expr AlterOpLayoutRewrite(const Call &ref_call, + const Array &new_args, + const NodeRef& ctx) { + std::vector inputs; + std::vector normal_new_args; - *ret = expr; + // NOTE: discard the "const" qualifier + TransformMemorizer memorizer = Downcast(ctx); + + // fill incomplete state + for (auto arg : new_args) { + if (const LayoutAlternatedExprNode *inp = arg.as()) { + inputs.push_back(GetRef(inp)); + normal_new_args.push_back(inp->value); + } else { + auto inode = make_node(); + inode->value = arg; + inode->memorizer = memorizer; + inputs.push_back(LayoutAlternatedExpr(inode)); + normal_new_args.push_back(arg); + } + } + + // old_in, new_in = state[inputs] + Array old_in, old_out, new_in, new_out, new_in2; + for (auto inp : inputs) { + old_in.push_back(inp->old_layout); + new_in.push_back(inp->new_layout); + } + + // old_in, old_out = op.infer(old_in) + bool success = false; + std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in); + if (!success) { return Expr(nullptr); } + CHECK_EQ(old_in.size(), new_in.size()); + + // if new_in == 'undef': new_in = old_in + for (size_t i = 0; i < new_in.size(); ++i) { + if (!new_in[i].defined()) { + new_in.Set(i, old_in[i]); + } + } + + // new_op = alter(op) + Call new_call = CallAlter(ref_call, normal_new_args); + + // new_in2, new_out = op.infer(new_in) + if (new_call->op->is_type()) { + success = false; + std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in); + if (!success) { return Expr(nullptr); } + } else { + return Expr(nullptr); + } + + CHECK_EQ(new_out.size(), old_out.size()) + << "The number of output nodes 
should keep the same during alter_op_layout"; + CHECK_EQ(new_in.size(), new_in2.size()) + << "The number of input nodes should keep the same during alter_op_layout"; + + // if (new_in != new_in2): insert transform (new_in -> new_in2) + Array transformed_args; + for (size_t i = 0; i < inputs.size(); ++i) { + transformed_args.push_back(memorizer.Transform(new_call->args[i], new_in[i], new_in2[i])); + } + + // state[node] = (old_out, new_out) + CHECK(ref_call->checked_type_.defined()) + << "Call infer_type pass before alter_op_layout pass"; + + if (ref_call->checked_type()->is_type()) { + Expr tuple_output = CallNode::make(new_call->op, transformed_args, + new_call->attrs, new_call->type_args); + Array fields; + for (size_t i = 0; i < new_out.size(); ++i) { + auto rnode = make_node(); + rnode->value = TupleGetItemNode::make(tuple_output, i); + rnode->old_layout = old_out[i]; + rnode->new_layout = new_out[i]; + rnode->memorizer = memorizer; + fields.push_back(Expr(rnode)); + } + return TupleNode::make(fields); + } else { + auto rnode = make_node(); + CHECK_EQ(new_out.size(), 1); + rnode->value = CallNode::make(new_call->op, transformed_args, + new_call->attrs, new_call->type_args); + rnode->old_layout = old_out[0]; + rnode->new_layout = new_out[0]; + rnode->memorizer = memorizer; + return Expr(rnode); + } +} + +TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") +.set_body([](TVMArgs args, TVMRetValue *ret) { + TransformMemorizer transformMemorizer(make_node()); + auto fcontext = [&](const Call& call) -> NodeRef{ + return transformMemorizer; + }; + + *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext); }); +} // namespace alter_op_layout + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/alter_op_layout.h b/src/relay/pass/alter_op_layout.h new file mode 100644 index 000000000000..a7305edb3c2a --- /dev/null +++ b/src/relay/pass/alter_op_layout.h @@ -0,0 +1,67 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alter_op_layout.h + * \brief Alternate the layouts of operators or replace primitive operators with + other expressions. This pass can be used for computing convolution in + custom layouts or other general weight pre-transformation. + */ + +#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ +#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ + +#include + +#include "../op/layout.h" + +namespace tvm { +namespace relay { + +/*! + * \brief Infer & correct function of node layout. See \p Layout for layout convention + * \param attrs The attribute of the node. + * \param input_layout The input layouts. + * \return infered_layout An array of two elements that are inferred input layouts and + * inferred output layouts. + */ +using FInferCorrectLayout = runtime::TypedPackedFunc< + Array>(const Attrs& attrs, + const Array& in_layouts)>; + +/*! \brief take arbitrary input layout and copy to output */ +inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, + const Array& in_layouts) { + Array inferred_ins; + + Layout in; + for (size_t i = 0; i < in_layouts.size(); ++i) { + if (!in.defined()) in = in_layouts[i]; + CHECK(in.Equals(in_layouts[i])) + << "Incompatible layout at " << i << "-th input: expected " << in + << ", got " << in_layouts[i]; + } + for (size_t i = 0; i < in_layouts.size(); ++i) { + inferred_ins.push_back(in); + } + + return Array >{inferred_ins, {in}}; +} + +/*! \brief Infer layout for binary broadcast operators. 
Prior to keep left layout */ +inline Array > BinaryBroadcastLayout(const Attrs& attrs, + const Array& in_layouts) { + CHECK_EQ(in_layouts.size(), 2); + Layout lhs = in_layouts[0]; + Layout rhs = in_layouts[1]; + + // prior to keep left layout + if (!lhs.defined()) { + lhs = rhs; + } + + return Array > {{lhs, lhs}, {lhs}}; +} + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 7873db80c6b0..a0cbc4a502c5 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -42,13 +42,20 @@ class TempRealizer : private ExprMutator { class ForwardRewriter : private ExprMutator { public: - ForwardRewriter(const OpMap& rewrite_map, + ForwardRewriter(const OpMap* rewrite_map, std::function fcontext, std::function fmulti_ref_trigger) : rewrite_map_(rewrite_map), fcontext_(fcontext), - fmulti_ref_trigger_(fmulti_ref_trigger) { - } + fmulti_ref_trigger_(fmulti_ref_trigger) {} + + ForwardRewriter(const FForwardRewrite* rewrite_func, + std::function fcontext, + std::function fmulti_ref_trigger) + : rewrite_func_(rewrite_func), + fcontext_(fcontext), + fmulti_ref_trigger_(fmulti_ref_trigger) {} + // Transform expression. Expr Rewrite(Expr expr) { @@ -60,8 +67,9 @@ class ForwardRewriter : private ExprMutator { private: // The rewrite rule. - const OpMap& rewrite_map_; - // The context. + const OpMap* rewrite_map_{nullptr}; + const FForwardRewrite* rewrite_func_{nullptr}; + // The context.const std::function fcontext_{nullptr}; // The multiple reference trigger std::function fmulti_ref_trigger_{nullptr}; @@ -106,7 +114,13 @@ class ForwardRewriter : private ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { const Call& ref_call = GetRef(call_node); - PackedFunc frewrite = rewrite_map_.get(call_node->op, nullptr); + PackedFunc frewrite; + if (rewrite_func_) { + frewrite = *rewrite_func_; + } else { + CHECK(rewrite_map_); + frewrite = rewrite_map_->get(call_node->op, nullptr); + } auto new_op = this->Mutate(call_node->op); bool unchanged = call_node->op.same_as(new_op); @@ -147,9 +161,16 @@ Expr ForwardRewrite(const Expr& expr, std::function fcontext, std::function fmulti_ref_trigger) { auto rewrite_map = Op::GetAttr(rewrite_map_name); - return ForwardRewriter(rewrite_map, - fcontext, - fmulti_ref_trigger).Rewrite(expr); + return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); } + +Expr ForwardRewrite(const Expr& expr, + const FForwardRewrite& rewrite_func, + std::function fcontext, + std::function fmulti_ref_trigger) { + return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); +} + + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 38ae923c5274..e6e8415bd620 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -73,7 +73,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, * the target Tensor on the specified axis via broadcasting rule. * * \param bias The bias. - * \param target_ndim target dimension. + * \param target_ndim Target dimension. * \param axes The axis on the output we want to match on. */ inline Expr ExpandBiasToMatchAxis(Expr bias, diff --git a/src/relay/pass/simplify_bias_add.cc b/src/relay/pass/simplify_bias_add.cc new file mode 100644 index 000000000000..fd810dc37a7b --- /dev/null +++ b/src/relay/pass/simplify_bias_add.cc @@ -0,0 +1,46 @@ +/*! 
+ * Copyright (c) 2018 by Contributors + * \file expand_bias_add.cc + * \brief Expand bias_add to expand_dims and broadcast_add. + * This can simplify the passes related to layout (e.g. alter_op_layout). + */ +#include +#include +#include +#include "pattern_util.h" + +namespace tvm { +namespace relay { + +class BiasAddSimplifier : public ExprMutator { + public: + Expr VisitExpr_(const CallNode* n) { + static const Op& bias_add = Op::Get("nn.bias_add"); + auto new_n = ExprMutator::VisitExpr_(n); + if (n->op.same_as(bias_add)) { + Call call = Downcast(new_n); + CHECK_EQ(call->args.size(), 2); + const BiasAddAttrs* param = call->attrs.as(); + + auto ttype = call->args[0]->type_as(); + size_t n_dim = ttype->shape.size(); + Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis}); + Expr ret = Add(call->args[0], expanded_bias); + ret->checked_type_ = n->checked_type_; + return ret; + } + return new_n; + } +}; + +Expr SimplifyBiasAdd(const Expr& e) { + return BiasAddSimplifier().Mutate(e); +} + +TVM_REGISTER_API("relay._ir_pass.simplify_bias_add") +.set_body([](TVMArgs args, TVMRetValue* ret) { +*ret = SimplifyBiasAdd(args[0]); +}); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index be42b4c9d732..1c9d3f258601 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -2,9 +2,10 @@ from tvm import relay from tvm.relay.op import register_alter_op_layout -from tvm.relay.ir_pass import alter_op_layout, alpha_equal, infer_type +from tvm.relay.ir_pass import * def test_alter_op(): + """Test directly replacing an operator with a new one""" def before(): x = relay.var("x", shape=(1, 64, 56, 56)) weight = relay.var('weight', shape=(64, 64, 3, 3)) @@ -17,7 +18,7 @@ def before(): return y @register_alter_op_layout("nn.conv2d", level=100) - def alter_conv2d(attrs, inputs): + def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs weight = relay.multiply(weight, relay.const(2.0)) return relay.nn.conv2d(data, weight, **attrs) @@ -44,6 +45,7 @@ def expected(): def test_alter_return_none(): + """Test doing nothing by returning 'None' """ def before(): x = relay.var("x", shape=(1, 64, 56, 56)) y = relay.nn.global_max_pool2d(x) @@ -53,18 +55,209 @@ def before(): called = [False] @register_alter_op_layout("nn.global_max_pool2d", level=101) - def alter_conv2d(attrs, inputs): + def alter_conv2d(attrs, inputs, tinfos): called[0] = True return None a = before() + a = infer_type(a) a = alter_op_layout(a) b = before() + b = infer_type(b) assert(alpha_equal(a, b)) assert(called[0]) +def test_alter_layout(): + """Test alternating the layout of a conv2d. + The layout of broadcast operators and the weight should be changed accordingly. 
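    This relies on simplify_bias_add expanding bias_add into expand_dims plus an ordinary
    broadcast add before the layout pass runs, so the broadcast layout rules can handle the
    bias tensor. Roughly (shapes as in this test; variable names are illustrative):

        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias", shape=(64,))
        y = relay.nn.bias_add(x, bias)  # axis=1 (the channel axis) by default
        # ... which simplify_bias_add rewrites into an equivalent broadcast add:
        y_expanded = relay.add(x, relay.expand_dims(bias, axis=1, num_newaxis=2))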
+ """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias") + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.bias_add(y, bias) + # a useless tuple, which will be eliminated + y = relay.Tuple([y])[0] + y = relay.nn.relu(y) + y = relay.nn.batch_flatten(y) + y = relay.Function(free_vars(y), y) + return y + + @register_alter_op_layout("nn.conv2d", level=102) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + new_attrs['weight_layout'] = 'OIHW16i' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + + y = relay.layout_transform(x, "NCHW", "NCHW16c") + w = relay.layout_transform(weight, "OIHW", "OIHW16i") + y = relay.nn.conv2d(y, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + weight_layout="OIHW16i", + data_layout="NCHW16c") + b = relay.expand_dims(bias, axis=1, num_newaxis=2) + b = relay.layout_transform(b, "NCHW", "NCHW16c") + y = relay.add(y, b) + + y = relay.nn.relu(y) + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.nn.batch_flatten(y) + y = relay.Function(free_vars(y), y) + return y + + a = before() + a = infer_type(a) + a = simplify_bias_add(a) + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + + +def test_alter_layout_dual_path(): + """ + Test alternating the layout with two outputs. + One path continues to use the new layout while one path fall backs to old layout. + """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.nn.relu(y1) + y2 = relay.nn.batch_flatten(y) + ret = relay.Tuple([y1, y2]) + y = relay.Function(free_vars(ret), ret) + return y + + @register_alter_op_layout("nn.conv2d", level=103) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + y = relay.nn.relu(y) + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NCHW16c') + y1 = relay.nn.relu(y1) + y1 = relay.layout_transform(y1, "NCHW16c", "NCHW") + y2 = relay.layout_transform(y, "NCHW16c", "NCHW") + y2 = relay.nn.batch_flatten(y2) + ret = relay.Tuple([y1, y2]) + y = relay.Function(free_vars(ret), ret) + return y + + a = before() + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + +def test_alter_layout_resnet(): + """Test alternating the layout of a residual block + This also tests the elimination of duplicated transformation. + If a same transformation applies to a same node twice, only one transformation will be created. 
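    Concretely, both convolutions end up reading from one shared NCHW -> NCHW16c transform
    of x, because the transform memorizer caches on (expr, src_layout, dst_layout). A sketch
    of the expected sharing, using this test's shapes:

        x = relay.var("x", shape=(1, 64, 56, 56))
        x16 = relay.layout_transform(x, "NCHW", "NCHW16c")  # emitted once, reused below
        lhs = relay.nn.conv2d(x16, relay.var("weight1"), channels=32, kernel_size=(3, 3),
                              padding=(1, 1), data_layout="NCHW16c")
        rhs = relay.nn.conv2d(x16, relay.var("weight2"), channels=32, kernel_size=(1, 1),
                              data_layout="NCHW16c")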
+ """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y2 = relay.nn.conv2d(x, weight2, + channels=32, + kernel_size=(1, 1)) + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.nn.global_max_pool2d(y) + return relay.Function(free_vars(y), y) + + @register_alter_op_layout("nn.conv2d", level=104) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + x = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + y = relay.nn.relu(y) + y2 = relay.nn.conv2d(x, weight2, + channels=32, + kernel_size=(1, 1), + data_layout='NCHW16c') + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.nn.global_max_pool2d(y, layout="NCHW16c") + y = relay.layout_transform(y, "NCHW16c", "NCHW") + return relay.Function(free_vars(y), y) + + a = before() + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + if __name__ == "__main__": test_alter_op() test_alter_return_none() + test_alter_layout() + test_alter_layout_dual_path() + test_alter_layout_resnet() diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 5fc05162f09b..9d3e675d8ef7 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -448,6 +448,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I, } using FLayoutIndicesTransform = std::function(const Array& indices)>; + /*! * \brief Transform the layout according to the mapping function \p to_src_indices. * \param src the source input. From 79727963fc367b2085c9fce43a9e7a739003ad96 Mon Sep 17 00:00:00 2001 From: Mercy Date: Tue, 27 Nov 2018 01:04:22 -0800 Subject: [PATCH 03/10] fix broadcast operators --- python/tvm/relay/op/_transform.py | 3 +- src/relay/op/layout.h | 12 +++- src/relay/op/nn/convolution.cc | 3 +- src/relay/op/nn/pooling.cc | 3 +- src/relay/op/tensor/transform.cc | 29 -------- src/relay/pass/alter_op_layout.cc | 26 ++++++-- src/relay/pass/alter_op_layout.h | 66 +++++++++++++++---- .../python/relay/test_pass_alter_op_layout.py | 55 +++++++++++++++- 8 files changed, 147 insertions(+), 50 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 7c336221b31c..a6b89f093d87 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1,5 +1,5 @@ """Backend compiler related feature registration""" -# pylint: disable=invalid-name +# pylint: disable=invalid-name,unused-argument from __future__ import absolute_import import topi from . 
import op as _reg @@ -11,6 +11,7 @@ _reg.register_schedule("collapse_sum_like", _schedule_reduce) _reg.register_schedule("broadcast_to_like", schedule_broadcast) +_reg.register_schedule("squeeze", schedule_injective) _reg.register_schedule("expand_dims", schedule_broadcast) _reg.register_schedule("squeeze", schedule_injective) _reg.register_schedule("reshape", schedule_injective) diff --git a/src/relay/op/layout.h b/src/relay/op/layout.h index 85c39ec9b24f..94dc9cd2a486 100644 --- a/src/relay/op/layout.h +++ b/src/relay/op/layout.h @@ -185,7 +185,7 @@ class Layout : public NodeRef { CHECK_GT(block_size, 0); new_layout << block_size; } - new_layout << layout_simplified[i]->value; + new_layout << static_cast(layout_simplified[i]->value); } return Layout(new_layout.str()); } @@ -241,6 +241,16 @@ class Layout : public NodeRef { return operator->()->layout_simplified.size(); } + /*! \return number of super dimension */ + size_t ndim_super() const { + size_t ct = 0; + for (auto x : operator->()->layout_simplified) { + if (IsSuperdim(x)) + ct++; + } + return ct; + } + /*! * \brief The description of the \p i-th dimension. * If it is a sub-dimension, the size will be returned as well, diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index b937a83c2f3a..65fb09a3e842 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -105,7 +105,8 @@ bool Conv2DRel(const Array& types, template Array > Conv2DInferCorrectLayout(const Attrs& attrs, - const Array& in_layouts) { + const Array& in_layouts, + const Array> &in_shapes) { const T* params = attrs.as(); Layout out_layout(params->out_layout); diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index a68b984f3081..bcc329bdbc11 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -21,7 +21,8 @@ TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); template Array > Pool2DInferCorrectLayout( const Attrs& attrs, - const Array& in_layouts) { + const Array& in_layouts, + const Array> &in_shapes) { CHECK_EQ(in_layouts.size(), 1); // NOTE: Discard "const" qualifier here. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 8b38fcb864fa..0d552973baa7 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1518,30 +1518,6 @@ RELAY_REGISTER_OP("slice_like") // relay.layout_transform -std::pair RemoveLeadingReduandantDimensions( - const Layout &src_layout, const Layout &dst_layout, size_t keep_size) { - // For example, when broadcasting (1, 64, 16, 16) with (64, 1, 1), - // we can still apply rule `NCHW -> NCHW16c` to the right tensor, - // by deleting the leading redundant dimension "N" and apply normal "CHW -> CHW16c". 
- CHECK_GE(src_layout.ndim(), keep_size) - << "Apply a " << src_layout.ndim() << "-dimensional rule " << src_layout - << " to " << keep_size << "-dimensional tensor"; - int n_remove = src_layout.ndim() - keep_size; - CHECK_GT(dst_layout.ndim(), n_remove); - for (int i = 0; i < n_remove; ++i) { - CHECK_EQ(src_layout[i], dst_layout[i]) - << "Can only delete the same dimension during layout transform"; - CHECK(Layout::IsSuperdim(src_layout[i])) - << "Can only delete a super dimension during layout transform"; - CHECK_EQ(src_layout.Subsizeof(src_layout[i]), -1) - << "Cannot delete a layout dimension with sub_dimension > 0"; - CHECK_EQ(dst_layout.Subsizeof(dst_layout[i]), -1) - << "Cannot delete a layout dimension with sub_dimension > 0"; - } - return std::make_pair(Layout(src_layout.name().substr(n_remove)), - Layout(dst_layout.name().substr(n_remove))); -} - Array LayoutTransformCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -1561,9 +1537,6 @@ Array LayoutTransformCompute(const Attrs& attrs, CHECK(src_layout.Convertible(dst_layout)) << "cannot convert from " << param->src_layout << " to " << param->dst_layout; - std::tie(src_layout, dst_layout) = RemoveLeadingReduandantDimensions( - src_layout, dst_layout, inputs[0]->shape.size()); - const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout); return Array { topi::layout_transform(inputs[0], out_shape, [&](const Array& dst_indices) { @@ -1607,8 +1580,6 @@ bool LayoutTransformRel(const Array& types, << "cannot convert from/to undefined layout"; CHECK(src_layout.Convertible(dst_layout)) << "cannot convert from " << params->src_layout << " to " << params->dst_layout; - std::tie(src_layout, dst_layout) = RemoveLeadingReduandantDimensions( - src_layout, dst_layout, data->shape.size()); const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout); reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype)); diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 369f8c622638..78632942877c 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -116,15 +116,23 @@ RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr); // Return inferred_input_layout, inferred_output_layout, success std::tuple, Array, bool> CallInfer( const Call& call, - const Array& inputs) { + const Array& in_layouts, + const Array>& in_shapes) { static auto finfer_layout = Op::GetAttr("FInferCorrectLayout"); Op op = Downcast(call->op); if (finfer_layout.count(op)) { Array > inferred_layouts; - inferred_layouts = finfer_layout[op](call->attrs, inputs); + inferred_layouts = finfer_layout[op](call->attrs, in_layouts, in_shapes); CHECK_EQ(inferred_layouts.size(), 2) << "FInferCorrectLayout should return an array with size of 2"; + for (auto x : inferred_layouts) { + for (auto y : x) { + if (!y.defined()) { // inference fails + return std::make_tuple<>(Array(nullptr), Array(nullptr), false); + } + } + } return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true); } else { return std::make_tuple<>(Array(nullptr), Array(nullptr), false); @@ -167,6 +175,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, const NodeRef& ctx) { std::vector inputs; std::vector normal_new_args; + Array> input_shapes; // NOTE: discard the "const" qualifier TransformMemorizer memorizer = Downcast(ctx); @@ -192,9 +201,13 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, new_in.push_back(inp->new_layout); } + for (auto arg : 
ref_call->args) { + input_shapes.push_back(arg->type_as()->shape); + } + // old_in, old_out = op.infer(old_in) bool success = false; - std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in); + std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in, input_shapes); if (!success) { return Expr(nullptr); } CHECK_EQ(old_in.size(), new_in.size()); @@ -211,7 +224,12 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // new_in2, new_out = op.infer(new_in) if (new_call->op->is_type()) { success = false; - std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in); + for (size_t i = 0; i < input_shapes.size(); ++i) { + if (old_in.defined()) { + input_shapes.Set(i, ConvertLayout(input_shapes[i], old_in[i], new_in[i])); + } + } + std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, input_shapes); if (!success) { return Expr(nullptr); } } else { return Expr(nullptr); diff --git a/src/relay/pass/alter_op_layout.h b/src/relay/pass/alter_op_layout.h index a7305edb3c2a..b5de670bdb57 100644 --- a/src/relay/pass/alter_op_layout.h +++ b/src/relay/pass/alter_op_layout.h @@ -19,17 +19,20 @@ namespace relay { /*! * \brief Infer & correct function of node layout. See \p Layout for layout convention * \param attrs The attribute of the node. - * \param input_layout The input layouts. + * \param in_layouts The layouts of input arguments. + * \param in_shapes The shapes of input arguments. * \return infered_layout An array of two elements that are inferred input layouts and * inferred output layouts. */ using FInferCorrectLayout = runtime::TypedPackedFunc< Array>(const Attrs& attrs, - const Array& in_layouts)>; + const Array& in_layouts, + const Array> &in_shapes)>; /*! \brief take arbitrary input layout and copy to output */ inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, - const Array& in_layouts) { + const Array& in_layouts, + const Array > &in_shapes) { Array inferred_ins; Layout in; @@ -46,19 +49,58 @@ inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, return Array >{inferred_ins, {in}}; } -/*! \brief Infer layout for binary broadcast operators. Prior to keep left layout */ +/*! \brief Infer layout for binary broadcast operators */ inline Array > BinaryBroadcastLayout(const Attrs& attrs, - const Array& in_layouts) { + const Array& in_layouts, + const Array > &in_shapes) { CHECK_EQ(in_layouts.size(), 2); - Layout lhs = in_layouts[0]; - Layout rhs = in_layouts[1]; + CHECK_EQ(in_shapes.size(), 2); - // prior to keep left layout - if (!lhs.defined()) { - lhs = rhs; - } + Array layouts = in_layouts; + + if (!layouts[0].defined() && !layouts[1].defined()) { + // both undefined, infer fails + return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } else if (!layouts[0].defined() || !layouts[1].defined()) { + // only one is defined, use shape information to help infer + int defined_idx = layouts[0].defined() ? 0 : 1; + int undef_idx = 1 - defined_idx; + + if (in_shapes[defined_idx].size() >= in_shapes[undef_idx].size()) { + layouts.Set(undef_idx, + layouts[defined_idx].Sublayout( + in_shapes[defined_idx].size() - in_shapes[undef_idx].size(), + in_shapes[undef_idx].size())); + return Array > {layouts, {layouts[defined_idx]}}; + } else { + // only know the tensor with smaller dimensions, + // so we cannot infer the final broadcasted output. + // fails in this case. 
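      // (Illustration, under assumed shapes: adding a 3-d bias with layout "CHW" to a
      // 4-d tensor whose layout is unknown. "CHW" cannot be stretched to cover the extra
      // batch dimension, so inference gives up and the call keeps its original layouts.)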
+ return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } + } else { + // try to broadcast to the tensors to the larger dimension + int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1; + int small_idx = 1 - large_idx; + Layout ret = layouts[large_idx]; - return Array > {{lhs, lhs}, {lhs}}; + // extract common part + size_t i = layouts[large_idx].ndim(); + for (; i != 0; --i) { + auto dim = layouts[large_idx][i-1]; + if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) { + break; + } + } + + Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i); + if (!layouts[small_idx].Convertible(common_part)) { // fail + return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } + + layouts.Set(small_idx, common_part); + return Array > {layouts, {ret}}; + } } } // namespace relay diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 1c9d3f258601..159200e24056 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -108,7 +108,7 @@ def expected(): weight_layout="OIHW16i", data_layout="NCHW16c") b = relay.expand_dims(bias, axis=1, num_newaxis=2) - b = relay.layout_transform(b, "NCHW", "NCHW16c") + b = relay.layout_transform(b, "CHW", "CHW16c") y = relay.add(y, b) y = relay.nn.relu(y) @@ -255,9 +255,62 @@ def expected(): assert(alpha_equal(a, b)) + +def test_alter_layout_broadcast_op(): + """Test boradcast operators """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + scale = relay.var("scale", shape=(64, 1, 1)) + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.bias_add(y, bias) # test broadcasting to lhs + y = relay.multiply(scale, y) # test broadcasting to rhs + y = relay.Function(free_vars(y), y) + return y + + @register_alter_op_layout("nn.conv2d", level=102) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + scale = relay.var("scale", shape=(64, 1, 1)) + weight = relay.var("weight") + x = relay.layout_transform(x, "NCHW", "NCHW16c") + bias = relay.expand_dims(bias, 1, 2) + bias = relay.layout_transform(bias, "CHW", "CHW16c") + scale = relay.layout_transform(scale, "CHW", "CHW16c") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), + data_layout="NCHW16c") + y = relay.add(y, bias) # test broadcasting to lhs + y = relay.multiply(scale, y) # test broadcasting to rhs + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.Function(free_vars(y), y) + return y + + a = before() + a = infer_type(a) + a = simplify_bias_add(a) + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + + if __name__ == "__main__": test_alter_op() test_alter_return_none() test_alter_layout() test_alter_layout_dual_path() test_alter_layout_resnet() + test_alter_layout_broadcast_op() + From 627f568a590f2ea6cfae18eae37235aee2598cbe Mon Sep 17 00:00:00 2001 From: Mercy Date: Tue, 27 Nov 2018 01:17:31 -0800 Subject: [PATCH 04/10] fix broadcast operators --- src/relay/op/layout.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/layout.h 
b/src/relay/op/layout.h index 94dc9cd2a486..90c920bf3aa1 100644 --- a/src/relay/op/layout.h +++ b/src/relay/op/layout.h @@ -241,7 +241,7 @@ class Layout : public NodeRef { return operator->()->layout_simplified.size(); } - /*! \return number of super dimension */ + /*! \return number of super dimensions */ size_t ndim_super() const { size_t ct = 0; for (auto x : operator->()->layout_simplified) { From d9188755b0dbc1d79b4f06f770c7907407a8485c Mon Sep 17 00:00:00 2001 From: Mercy Date: Tue, 27 Nov 2018 01:24:20 -0800 Subject: [PATCH 05/10] fix broadcast operators --- include/tvm/relay/pass.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 298d1f77649f..8fff7016a827 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -179,11 +179,14 @@ Expr ForwardRewrite(const Expr& expr, * \param expr The expression. * \param rewrite_func The rewrite func that will apply to all operators. * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. * \return The rewritten expression. */ Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr); + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); /*! \brief A hashing structure in the style of std::hash. */ From dc098ca6cadffcc517ed7be5dc48d45f5972ec75 Mon Sep 17 00:00:00 2001 From: Mercy Date: Tue, 27 Nov 2018 22:24:05 -0800 Subject: [PATCH 06/10] Support concatenate --- 3rdparty/HalideIR | 2 +- src/relay/op/nn/convolution.cc | 10 +++-- src/relay/op/nn/pooling.cc | 28 +++++++------- src/relay/op/tensor/transform.cc | 40 ++++++++++++++++++- src/relay/pass/alter_op_layout.cc | 61 +++++++++++++++++++---------- src/relay/pass/alter_op_layout.h | 64 ++++++++++++++++++------------- src/relay/pass/forward_rewrite.cc | 16 ++++++++ 7 files changed, 153 insertions(+), 68 deletions(-) diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index e4a4c02764d3..a08e26e5a97f 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit e4a4c02764d37c9c3db0d64c4996651a3ef9513c +Subproject commit a08e26e5a97f4ef4d566a42f6c78704b3f9c7b8a diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 65fb09a3e842..170b6b6d13c5 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -104,12 +104,16 @@ bool Conv2DRel(const Array& types, } template -Array > Conv2DInferCorrectLayout(const Attrs& attrs, - const Array& in_layouts, - const Array> &in_shapes) { +Array > Conv2DInferCorrectLayout( + const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { const T* params = attrs.as(); Layout out_layout(params->out_layout); + // We always make other operators to fit the layouts of convolution layers + // So this inference ignores all inputs return Array >{{params->data_layout, params->weight_layout}, {out_layout.defined() ? 
out_layout : params->data_layout}}; } diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index bcc329bdbc11..6233e6d51776 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -21,24 +21,22 @@ TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); template Array > Pool2DInferCorrectLayout( const Attrs& attrs, - const Array& in_layouts, - const Array> &in_shapes) { - CHECK_EQ(in_layouts.size(), 1); - + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { // NOTE: Discard "const" qualifier here. T *params = const_cast(attrs.as()); - Layout input = in_layouts[0]; - const Layout raw_layout(params->layout); - if (input.defined()) { - CHECK(input.Convertible(raw_layout)); - if (input.Indexof('W') != raw_layout.Indexof('W') || - input.Indexof('H') != raw_layout.Indexof('H') || - input.Contains('w') || input.Contains('h')) { - // if the new layout changes width or height dimension, - // fallback to old layout; - input = raw_layout; + + if (new_in_layouts.defined()) { + CHECK_EQ(new_in_layouts.size(), 1); + + Layout raw_layout(params->layout); + Layout input = new_in_layouts[0]; + if (input.Indexof('W') == raw_layout.Indexof('W') && + input.Indexof('H') == raw_layout.Indexof('H') && + !input.Contains('w') && !input.Contains('h')) { + params->layout = input.name(); // modify self to follow the input layout } - params->layout = input.name(); // modify self to follow the input layout } return Array >{{params->layout}, {params->layout}}; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 0d552973baa7..fcf7f6fe3299 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -15,6 +15,7 @@ #include #include "../op_common.h" #include "../../../arithmetic/compute_expr.h" +#include "../../pass/alter_op_layout.h" #include "../layout.h" namespace tvm { @@ -204,6 +205,42 @@ bool ConcatenateRel(const Array& types, return true; } +Array> ConcatenateLayout( + const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + const ConcatenateAttrs* param = attrs.as(); + + size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : + static_cast(param->axis); + + Layout ret; + if (new_in_layouts.defined()) { // this function is called after some operators are alternated. 
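    // (A concrete case for illustration: concatenating two conv2d outputs along axis=1
    // after both were altered from NCHW to NCHW16c. The old layouts make the
    // concatenation dimension 'C'; NCHW16c still carries 'C' at index 1, so NCHW16c is
    // chosen for every input and for the output.)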
+ Layout::LayoutDim concate_dim = old_in_layouts[0][axis]; + for (size_t i = 0; i < new_in_layouts.size(); ++i) { + if (new_in_layouts[i].ndim() > axis && + new_in_layouts[i][axis] == concate_dim) { + ret = new_in_layouts[i]; + break; + } + } + } else { // this function is called on the original correct relay ir + for (size_t i = 0; i < old_in_layouts.size(); ++i) { + if (old_in_layouts[i].defined()) { + ret = old_in_layouts[i]; + break; + } + } + + if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) { + return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } + } + + return Array > {Array(old_in_layouts.size(), ret), {ret}}; +} + Expr MakeConcatenate(Expr data, int axis) { auto attrs = make_node(); @@ -229,7 +266,8 @@ RELAY_REGISTER_OP("concatenate") .set_num_inputs(1) .add_argument("data", "Tensor", "The input list of tensors.") .set_support_level(1) -.add_type_rel("Concatenate", ConcatenateRel); +.add_type_rel("Concatenate", ConcatenateRel) +.set_attr("FInferCorrectLayout", ConcatenateLayout); /* relay.transpose */ TVM_REGISTER_NODE_TYPE(TransposeAttrs); diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 78632942877c..b3f7a478dcc3 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -66,6 +66,9 @@ class TransformMemorizer : public NodeRef { // Transform layout with memorizer Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) { + if (src_layout.Equals(dst_layout)) + return raw; + std::tuple key = std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name()); auto& memo = operator->()->memo; @@ -116,14 +119,16 @@ RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr); // Return inferred_input_layout, inferred_output_layout, success std::tuple, Array, bool> CallInfer( const Call& call, - const Array& in_layouts, - const Array>& in_shapes) { + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { static auto finfer_layout = Op::GetAttr("FInferCorrectLayout"); Op op = Downcast(call->op); if (finfer_layout.count(op)) { Array > inferred_layouts; - inferred_layouts = finfer_layout[op](call->attrs, in_layouts, in_shapes); + inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts, + old_in_layouts, old_in_shapes); CHECK_EQ(inferred_layouts.size(), 2) << "FInferCorrectLayout should return an array with size of 2"; for (auto x : inferred_layouts) { @@ -180,17 +185,27 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // NOTE: discard the "const" qualifier TransformMemorizer memorizer = Downcast(ctx); - // fill incomplete state - for (auto arg : new_args) { - if (const LayoutAlternatedExprNode *inp = arg.as()) { - inputs.push_back(GetRef(inp)); - normal_new_args.push_back(inp->value); + // fill incomplete state and expand tuple + for (auto new_arg : new_args) { + auto push_back_one_arg = [&](Expr arg) { + if (const LayoutAlternatedExprNode *inp = arg.as()) { + inputs.push_back(GetRef(inp)); + normal_new_args.push_back(inp->value); + } else { + auto inode = make_node(); + inode->value = arg; + inode->memorizer = memorizer; + inputs.push_back(LayoutAlternatedExpr(inode)); + normal_new_args.push_back(arg); + } + }; + if (new_arg->is_type()) { + Tuple tuple_new_arg = Downcast(new_arg); + for (auto x : tuple_new_arg->fields) { + push_back_one_arg(x); + } } else { - auto inode = make_node(); - inode->value = arg; - inode->memorizer = memorizer; - inputs.push_back(LayoutAlternatedExpr(inode)); - 
normal_new_args.push_back(arg); + push_back_one_arg(new_arg); } } @@ -202,12 +217,21 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, } for (auto arg : ref_call->args) { - input_shapes.push_back(arg->type_as()->shape); + if (arg->is_type()) { // expand tuple + Tuple tuple_arg = Downcast(arg); + for (auto x : tuple_arg->fields) { + input_shapes.push_back(x->type_as()->shape); + } + } else { + input_shapes.push_back(arg->type_as()->shape); + } } // old_in, old_out = op.infer(old_in) bool success = false; - std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in, input_shapes); + std::tie(old_in, old_out, success) = CallInfer(ref_call, + Array(nullptr), + old_in, input_shapes); if (!success) { return Expr(nullptr); } CHECK_EQ(old_in.size(), new_in.size()); @@ -224,12 +248,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // new_in2, new_out = op.infer(new_in) if (new_call->op->is_type()) { success = false; - for (size_t i = 0; i < input_shapes.size(); ++i) { - if (old_in.defined()) { - input_shapes.Set(i, ConvertLayout(input_shapes[i], old_in[i], new_in[i])); - } - } - std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, input_shapes); + std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, old_in, input_shapes); if (!success) { return Expr(nullptr); } } else { return Expr(nullptr); diff --git a/src/relay/pass/alter_op_layout.h b/src/relay/pass/alter_op_layout.h index b5de670bdb57..fcb7b379a0ec 100644 --- a/src/relay/pass/alter_op_layout.h +++ b/src/relay/pass/alter_op_layout.h @@ -19,44 +19,54 @@ namespace relay { /*! * \brief Infer & correct function of node layout. See \p Layout for layout convention * \param attrs The attribute of the node. - * \param in_layouts The layouts of input arguments. - * \param in_shapes The shapes of input arguments. + * \param new_in_layouts The layouts of input arguments after alter_op_layout. + * This can be undefined, which means we call this function before alternating + * any operators. + * \param old_in_layouts The layouts of input arguments before alter_op_layout. + * \param old_in_shapes The shapes of old input arguments. * \return infered_layout An array of two elements that are inferred input layouts and * inferred output layouts. */ using FInferCorrectLayout = runtime::TypedPackedFunc< Array>(const Attrs& attrs, - const Array& in_layouts, - const Array> &in_shapes)>; + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes)>; /*! \brief take arbitrary input layout and copy to output */ inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, - const Array& in_layouts, - const Array > &in_shapes) { - Array inferred_ins; - - Layout in; - for (size_t i = 0; i < in_layouts.size(); ++i) { - if (!in.defined()) in = in_layouts[i]; - CHECK(in.Equals(in_layouts[i])) - << "Incompatible layout at " << i << "-th input: expected " << in - << ", got " << in_layouts[i]; - } - for (size_t i = 0; i < in_layouts.size(); ++i) { - inferred_ins.push_back(in); + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + Layout ret; + + if (new_in_layouts.defined()) { + CHECK_GE(new_in_layouts.size(), 1); + ret = new_in_layouts[0]; + } else { + for (size_t i = 0; i < old_in_layouts.size(); ++i) { + if (old_in_layouts[i].defined()) { + ret = old_in_layouts[i]; + break; + } + } } - return Array >{inferred_ins, {in}}; + return Array >{Array(old_in_layouts.size(), ret), {ret}}; } /*! 
\brief Infer layout for binary broadcast operators */ inline Array > BinaryBroadcastLayout(const Attrs& attrs, - const Array& in_layouts, - const Array > &in_shapes) { - CHECK_EQ(in_layouts.size(), 2); - CHECK_EQ(in_shapes.size(), 2); + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + Array layouts; - Array layouts = in_layouts; + if (new_in_layouts.defined()) { + layouts.assign(new_in_layouts.begin(), new_in_layouts.end()); + } else { + layouts.assign(old_in_layouts.begin(), old_in_layouts.end()); + } if (!layouts[0].defined() && !layouts[1].defined()) { // both undefined, infer fails @@ -66,11 +76,11 @@ inline Array > BinaryBroadcastLayout(const Attrs& attrs, int defined_idx = layouts[0].defined() ? 0 : 1; int undef_idx = 1 - defined_idx; - if (in_shapes[defined_idx].size() >= in_shapes[undef_idx].size()) { + if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) { layouts.Set(undef_idx, layouts[defined_idx].Sublayout( - in_shapes[defined_idx].size() - in_shapes[undef_idx].size(), - in_shapes[undef_idx].size())); + old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(), + old_in_shapes[undef_idx].size())); return Array > {layouts, {layouts[defined_idx]}}; } else { // only know the tensor with smaller dimensions, @@ -79,7 +89,7 @@ inline Array > BinaryBroadcastLayout(const Attrs& attrs, return Array > {{Layout::Undef()}, {Layout::Undef()}}; } } else { - // try to broadcast to the tensors to the larger dimension + // try to broadcast the tensors to the larger dimension int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1; int small_idx = 1 - large_idx; Layout ret = layouts[large_idx]; diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index a0cbc4a502c5..4f33d4a053b7 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -112,6 +112,22 @@ class ForwardRewriter : private ExprMutator { } } + Expr VisitExpr_(const TupleNode* op) final { + tvm::Array fields; + bool all_fields_unchanged = true; + for (auto field : op->fields) { + auto new_field = this->GetTempExpr(field); + fields.push_back(new_field); + all_fields_unchanged &= new_field.same_as(field); + } + + if (all_fields_unchanged) { + return GetRef(op); + } else { + return TupleNode::make(fields); + } + } + Expr VisitExpr_(const CallNode* call_node) final { const Call& ref_call = GetRef(call_node); PackedFunc frewrite; From 70ac9d3285d99d2b28634547cf0da86f2453870d Mon Sep 17 00:00:00 2001 From: Mercy Date: Wed, 28 Nov 2018 18:17:24 -0800 Subject: [PATCH 07/10] address comments --- python/tvm/relay/build_module.py | 2 +- python/tvm/relay/ir_pass.py | 8 ++++---- src/relay/pass/alter_op_layout.cc | 17 +++++++++-------- ...simplify_bias_add.cc => canonicalize_ops.cc} | 12 ++++++------ tests/python/relay/test_pass_alter_op_layout.py | 4 ++-- 5 files changed, 22 insertions(+), 21 deletions(-) rename src/relay/pass/{simplify_bias_add.cc => canonicalize_ops.cc} (77%) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 7b20f59f091b..2a2cd9f82ecb 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -161,7 +161,7 @@ def optimize(func, params=None): if cfg.pass_enabled("AlterOpLayout"): func = ir_pass.infer_type(func) - func = ir_pass.simplify_bias_add(func) + func = ir_pass.canonicalize_ops(func) func = ir_pass.infer_type(func) func = ir_pass.alter_op_layout(func) diff --git a/python/tvm/relay/ir_pass.py 
b/python/tvm/relay/ir_pass.py index f58f9ba68370..53fa59cd053d 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -191,9 +191,9 @@ def simplify_inference(expr): return _ir_pass.simplify_inference(expr) -def simplify_bias_add(expr): - """ Simplify the bias_add to expand_dims and broadcast_add. - This can simplify latter layout related passes (e.g. alter_op_layout) +def canonicalize_ops(expr): + """ Canonicalize special operators to basic operators. + This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) Parameters ---------- @@ -205,7 +205,7 @@ def simplify_bias_add(expr): result: tvm.relay.Expr An expression without bias_add """ - return _ir_pass.simplify_bias_add(expr) + return _ir_pass.canonicalize_ops(expr) def dead_code_elimination(expr): diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index b3f7a478dcc3..23e5ec45bb56 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -23,8 +23,7 @@ namespace alter_op_layout { // Make a transform CallNode Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { - if (src_layout.Equals(dst_layout)) - return raw; + if (src_layout.Equals(dst_layout)) { return raw; } CHECK(src_layout.defined() && dst_layout.defined()) << "Cannot insert layout transform because there are undefined layouts"; CHECK(src_layout.Convertible(dst_layout)) @@ -41,12 +40,12 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { // Memorize layout transform so we can reuse internal transformed nodes class TransformMemorizerNode : public Node { public: + // map from (Expr, src_layout, dst_layout) to transformed Expr using TransformKey = std::tuple; struct key_hash : public std::unary_function { std::size_t operator()(const TransformKey& k) const { - return std::hash()(std::get<0>(k)) ^ - std::hash()(std::get<1>(k)) ^ - std::hash()(std::get<2>(k)); + return dmlc::HashCombine(dmlc::HashCombine( + std::hash()(std::get<0>(k)), std::get<1>(k)), (std::get<2>(k))); } }; @@ -66,8 +65,7 @@ class TransformMemorizer : public NodeRef { // Transform layout with memorizer Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) { - if (src_layout.Equals(dst_layout)) - return raw; + if (src_layout.Equals(dst_layout)) { return raw; } std::tuple key = std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name()); @@ -180,7 +178,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, const NodeRef& ctx) { std::vector inputs; std::vector normal_new_args; - Array> input_shapes; + Array > input_shapes; // NOTE: discard the "const" qualifier TransformMemorizer memorizer = Downcast(ctx); @@ -188,6 +186,8 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // fill incomplete state and expand tuple for (auto new_arg : new_args) { auto push_back_one_arg = [&](Expr arg) { + // We always expect LayoutAlternatedExpr. + // This is used to convert the normal expr into LayoutAlternatedExpr. 
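      // (Plain Exprs typically reach this point as function parameters or constants that
      // no earlier rewrite produced; they are wrapped here with undefined layouts, which
      // the layout inference below then fills in.)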
       if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
         inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
         normal_new_args.push_back(inp->value);
@@ -199,6 +199,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
         normal_new_args.push_back(arg);
       }
     };
+
     if (new_arg->is_type<TupleNode>()) {
       Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
       for (auto x : tuple_new_arg->fields) {
diff --git a/src/relay/pass/simplify_bias_add.cc b/src/relay/pass/canonicalize_ops.cc
similarity index 77%
rename from src/relay/pass/simplify_bias_add.cc
rename to src/relay/pass/canonicalize_ops.cc
index fd810dc37a7b..77cd59e2afd8 100644
--- a/src/relay/pass/simplify_bias_add.cc
+++ b/src/relay/pass/canonicalize_ops.cc
@@ -1,8 +1,8 @@
 /*!
  * Copyright (c) 2018 by Contributors
- * \file expand_bias_add.cc
- * \brief Expand bias_add to expand_dims and broadcast_add.
- * This can simplify the passes related to layout (e.g. alter_op_layout).
+ * \file canonicalize_ops.cc
+ * \brief Canonicalize special operators to basic operators.
+   This can simplify later analysis (e.g. expand bias_add to expand_dims and broadcast_add).
  */
 #include
 #include
@@ -33,13 +33,13 @@ class BiasAddSimplifier : public ExprMutator {
   }
 };
 
-Expr SimplifyBiasAdd(const Expr& e) {
+Expr CanonicalizeOps(const Expr& e) {
   return BiasAddSimplifier().Mutate(e);
 }
 
-TVM_REGISTER_API("relay._ir_pass.simplify_bias_add")
+TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-*ret = SimplifyBiasAdd(args[0]);
+*ret = CanonicalizeOps(args[0]);
 });
 
 }  // namespace relay
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 159200e24056..6a8be7ea847e 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -119,7 +119,7 @@ def expected():
 
     a = before()
     a = infer_type(a)
-    a = simplify_bias_add(a)
+    a = canonicalize_ops(a)
     a = infer_type(a)
     a = alter_op_layout(a)
     a = infer_type(a)
@@ -295,7 +295,7 @@ def expected():
 
     a = before()
     a = infer_type(a)
-    a = simplify_bias_add(a)
+    a = canonicalize_ops(a)
     a = infer_type(a)
     a = alter_op_layout(a)
     a = infer_type(a)

From 1f10fcdf050625125b6b3aad9b305f6cb1138e10 Mon Sep 17 00:00:00 2001
From: Mercy
Date: Wed, 28 Nov 2018 18:19:51 -0800
Subject: [PATCH 08/10] address comments

---
 src/relay/pass/alter_op_layout.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc
index 23e5ec45bb56..7e34864fefc6 100644
--- a/src/relay/pass/alter_op_layout.cc
+++ b/src/relay/pass/alter_op_layout.cc
@@ -113,7 +113,7 @@ class LayoutAlternatedExprNode : public TempExprNode {
 
 RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr);
 
-// Call FInferCorrectLayout of an op.
+// Call registered FInferCorrectLayout of an op.
 // Return inferred_input_layout, inferred_output_layout, success
 std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
     const Call& call,
@@ -187,7 +187,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
   for (auto new_arg : new_args) {
     auto push_back_one_arg = [&](Expr arg) {
       // We always expect LayoutAlternatedExpr.
-      // This is used to convert the normal expr into LayoutAlternatedExpr.
+      // This is used to convert the normal Expr to LayoutAlternatedExpr.
       if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
         inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
         normal_new_args.push_back(inp->value);

From e9cf2bb8190c2d45586f1c6029779383cc8c0dc9 Mon Sep 17 00:00:00 2001
From: Mercy
Date: Wed, 28 Nov 2018 18:26:55 -0800
Subject: [PATCH 09/10] add comments

---
 src/relay/pass/alter_op_layout.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc
index 7e34864fefc6..5c4475259086 100644
--- a/src/relay/pass/alter_op_layout.cc
+++ b/src/relay/pass/alter_op_layout.cc
@@ -114,12 +114,13 @@ class LayoutAlternatedExprNode : public TempExprNode {
 RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr);
 
 // Call registered FInferCorrectLayout of an op.
-// Return inferred_input_layout, inferred_output_layout, success
+// Parameters are the same as the parameters for FInferCorrectLayout
+// Returns inferred_input_layout, inferred_output_layout, success
 std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
     const Call& call,
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr>> &old_in_shapes) {
+    const Array<Array<IndexExpr> > &old_in_shapes) {
   static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
 
   Op op = Downcast<Op>(call->op);
@@ -143,7 +144,7 @@ std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
 }
 
 // Call registered FTVMAlterOpLayout of an op
-// Return altered expression
+// Returns the altered expression
 Call CallAlter(const Call& ref_call,
                const std::vector<Expr>& new_args) {
   static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");

From 8849f9a16af9aae13c74ca6cce340cc3ee4f3da4 Mon Sep 17 00:00:00 2001
From: Mercy
Date: Thu, 29 Nov 2018 14:34:37 -0800
Subject: [PATCH 10/10] rebase

---
 python/tvm/relay/op/_transform.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index a6b89f093d87..1aaf376a7dc8 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -11,7 +11,6 @@
 _reg.register_schedule("collapse_sum_like", _schedule_reduce)
 _reg.register_schedule("broadcast_to_like", schedule_broadcast)
-_reg.register_schedule("squeeze", schedule_injective)
 _reg.register_schedule("expand_dims", schedule_broadcast)
 _reg.register_schedule("squeeze", schedule_injective)
 _reg.register_schedule("reshape", schedule_injective)
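
A minimal usage sketch (illustration only, not part of the patch series): it assumes the
two-argument (attrs, inputs) form of the FTVMAlterOpLayout typedef introduced by these
patches, the register_alter_op_layout / canonicalize_ops / alter_op_layout Python entry
points added above, and purely illustrative Conv2D details (the "data_layout" attribute
name, the "NCHW16c" target layout, and the level value are all assumptions).

    from tvm import relay
    from tvm.relay import ir_pass
    from tvm.relay.op import register_alter_op_layout

    @register_alter_op_layout("nn.conv2d", level=100)
    def alter_conv2d(attrs, inputs):
        data, weight = inputs
        # Copy the original attributes (Attrs exposes keys()/__getitem__)
        # and only override the data layout; everything else is kept.
        new_attrs = {k: attrs[k] for k in attrs.keys()}
        new_attrs["data_layout"] = "NCHW16c"
        return relay.nn.conv2d(data, weight, **new_attrs)

    x = relay.var("x", shape=(1, 64, 56, 56))
    w = relay.var("w", shape=(64, 64, 3, 3))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
    func = relay.Function([x, w], y)

    # Same pass ordering as optimize() in build_module.py above.
    func = ir_pass.infer_type(func)
    func = ir_pass.canonicalize_ops(func)
    func = ir_pass.infer_type(func)
    func = ir_pass.alter_op_layout(func)  # layout_transform ops are inserted where layouts change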