diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 724749368aa9..103359e3617c 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -38,7 +38,7 @@ struct Conv2DAttrs : public tvm::AttrsNode { IndexExpr channels; Array kernel_size; std::string data_layout; - std::string weight_layout; + std::string kernel_layout; std::string out_layout; DataType out_dtype; @@ -68,7 +68,7 @@ struct Conv2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Convolution is applied on the 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(weight_layout).set_default("OIHW") + TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "dimensions respectively."); @@ -84,13 +84,85 @@ struct Conv2DAttrs : public tvm::AttrsNode { } }; + +/*! \brief Attributes used in winograd weight transformation operators */ +struct Conv2DWinogradWeightTransformAttrs : + public tvm::AttrsNode { + int tile_size; + + TVM_DECLARE_ATTRS(Conv2DWinogradWeightTransformAttrs, + "relay.attrs.Conv2DWinogradWeightTransformAttrs") { + TVM_ATTR_FIELD(tile_size) + .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + } +}; + +/*! \brief Attributes used in convolution operators with winograd algorithm */ +struct Conv2DWinogradAttrs : public tvm::AttrsNode { + int tile_size; + Array strides; + Array padding; + Array dilation; + int groups; + IndexExpr channels; + Array kernel_size; + std::string data_layout; + std::string kernel_layout; + std::string out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") { + TVM_ATTR_FIELD(tile_size) + .describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1) + .describe("Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(channels) + .describe("The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") + .set_default(NullValue()); + TVM_ATTR_FIELD(kernel_size) + .describe("Specifies the dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(data_layout).set_default("NCHW") + .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") + .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." 
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout).set_default("") + .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } +}; + + /*! \brief Attributes used in softmax operators */ struct SoftmaxAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("The axis to sum over when computing softmax."); + TVM_ATTR_FIELD(axis).set_default(-1) + .describe("The axis to sum over when computing softmax."); } }; @@ -104,7 +176,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { Array dilation; int groups; std::string data_layout; - std::string weight_layout; + std::string kernel_layout; std::string out_layout; DataType out_dtype; @@ -136,7 +208,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Convolution is applied on the 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(weight_layout).set_default("OIHW") + TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "dimensions respectively."); diff --git a/nnvm/python/nnvm/to_relay.py b/nnvm/python/nnvm/to_relay.py index a168f4fd88d2..8efa401c0fef 100644 --- a/nnvm/python/nnvm/to_relay.py +++ b/nnvm/python/nnvm/to_relay.py @@ -1,11 +1,12 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-argument """Convert an NNVM graph to Relay.""" import json +import numpy + from tvm import relay, nd from tvm.relay import op, expr, var from tvm.relay.frontend.common import StrAttrsDict from tvm.relay.frontend.nnvm_common import _rename -import numpy from .symbol import Symbol from .compiler import graph_attr from .graph import create as graph_create @@ -42,7 +43,7 @@ def _conv2d(children, attrs, odtype='float32'): dilation = attrs.get_int_tuple('dilation', (1, 1)) groups = attrs.get_int('groups', 1) data_layout = attrs.get_str('layout', 'NCHW') - weight_layout = attrs.get_str('kernel_layout', 'OIHW') + kernel_layout = attrs.get_str('kernel_layout', 'OIHW') out_layout = '' out_dtype = attrs.get_str('out_dtype', '') @@ -54,7 +55,7 @@ def _conv2d(children, attrs, odtype='float32'): dilation=dilation, groups=groups, data_layout=data_layout, - weight_layout=weight_layout, + kernel_layout=kernel_layout, out_layout=out_layout, out_dtype=out_dtype) @@ -77,7 +78,7 @@ def _conv2d_transpose(children, attrs, odtype='float32'): dilation = attrs.get_int_tuple('dilation', (1, 1)) groups = attrs.get_int('groups', 1) data_layout = attrs.get_str('layout', 'NCHW') - weight_layout = attrs.get_str('kernel_layout', 'OIHW') + kernel_layout = attrs.get_str('kernel_layout', 'OIHW') out_dtype = attrs.get_str('out_dtype', '') out_conv2d = op.nn.conv2d_transpose( @@ -88,7 +89,7 @@ def _conv2d_transpose(children, attrs, odtype='float32'): dilation=dilation, groups=groups, data_layout=data_layout, - weight_layout=weight_layout, + kernel_layout=kernel_layout, out_dtype=out_dtype) if use_bias: diff 
--git a/nnvm/python/nnvm/top/attr_dict.py b/nnvm/python/nnvm/top/attr_dict.py index efd439fa75fc..834fffdd01c2 100644 --- a/nnvm/python/nnvm/top/attr_dict.py +++ b/nnvm/python/nnvm/top/attr_dict.py @@ -138,7 +138,7 @@ def get_bool(self, key): else: raise ValueError("Wrong bool format for key %s" % key) - def get_string(self, key): + def get_str(self, key): """Get string from attr dict Parameters diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index a37a5d7e071e..3aaafaed62f9 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -153,7 +153,25 @@ def schedule_conv2d(attrs, outs, target): @reg.register_alter_op_layout("conv2d") def alter_conv2d_layout(attrs, inputs, tinfos): - return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos) + """Replace conv2d op with other layouts or algorithms""" + import nnvm.symbol as sym + + # map relay op names to nnvm op names + sym.contrib_conv2d_winograd_without_weight_transform = \ + sym.contrib.conv2d_winograd_without_weight_transform + sym.contrib_conv2d_winograd_weight_transform = \ + sym.contrib.conv2d_winograd_weight_transform + sym.nn = sym + + # map relay argument names to nnvm argument names + raw_reshape = sym.reshape + def _reshape(*args, **kwargs): + if "newshape" in kwargs: + kwargs['shape'] = kwargs.pop('newshape') + return raw_reshape(*args, **kwargs) + sym.reshape = _reshape + + return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, sym) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -166,9 +184,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): dilation = attrs.get_int_tuple("dilation") out_channel = attrs.get_int("channels") groups = attrs.get_int("groups") - layout = attrs.get_string("layout") - out_layout = attrs.get_string("out_layout") - out_dtype = attrs.get_string("out_dtype") + layout = attrs.get_str("layout") + out_layout = attrs.get_str("out_layout") + out_dtype = attrs.get_str("out_dtype") out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype if layout == "NCHW": _, in_channel, _, _ = get_const_tuple(inputs[0].shape) @@ -227,8 +245,8 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, _): strides = attrs.get_int_tuple("strides") dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") - layout = attrs.get_string("layout") - out_dtype = attrs.get_string("out_dtype") + layout = attrs.get_str("layout") + out_dtype = attrs.get_str("out_dtype") tile_size = attrs.get_int("tile_size") out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype assert dilation == (1, 1), "Do not support dilate now" @@ -262,7 +280,7 @@ def compute_conv2d_transpose(attrs, inputs, _): strides = attrs.get_int_tuple("strides") dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") - out_dtype = attrs.get_string("out_dtype") + out_dtype = attrs.get_str("out_dtype") layout = attrs["layout"] out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype diff --git a/python/tvm/attrs.py b/python/tvm/attrs.py index 529dbcc14c13..08548d181883 100644 --- a/python/tvm/attrs.py +++ b/python/tvm/attrs.py @@ -33,6 +33,45 @@ def keys(self): for field in fields: yield field.name + def get_int_tuple(self, key): + """Get a python int tuple of a key + + Parameters + ---------- + key: str + + Returns + ------- + value: Tuple of int + """ + return tuple(x.value for x in self.__getattr__(key)) + + def get_int(self, key): + """Get a python int value of a key + + Parameters + ---------- + key: str + + Returns + ------- + 
value: int + """ + return self.__getattr__(key) + + def get_str(self, key): + """Get a python int value of a key + + Parameters + ---------- + key: str + + Returns + ------- + value: int + """ + return self.__getattr__(key) + def __getitem__(self, item): return self.__getattr__(item) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 5b05bc44551a..01bf1bc25d5e 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -119,7 +119,7 @@ def _bind_params_by_name(func, params): return expr.bind(func, bind_dict) -def optimize(func, params=None): +def optimize(func, target, params=None): """Perform target invariant optimizations. Parameters @@ -127,6 +127,9 @@ def optimize(func, params=None): func : tvm.relay.Function The input to optimization. + target: :any:`tvm.target.Target` + The optimization target. Some optimization passes are target specific. + params : Optional[Dict[str, tvm.nd.NDArray]] Input parameters to the graph that do not change during inference time. used for constant folding. @@ -164,7 +167,11 @@ def optimize(func, params=None): func = ir_pass.infer_type(func) func = ir_pass.canonicalize_ops(func) func = ir_pass.infer_type(func) - func = ir_pass.alter_op_layout(func) + with target: + func = ir_pass.alter_op_layout(func) + + if cfg.pass_enabled("FoldConstant"): + func = ir_pass.fold_constant(func) return func @@ -222,7 +229,7 @@ def build(func, cfg = BuildConfig.current with tophub_context: - func = optimize(func, params) + func = optimize(func, target, params) # Fuse ops before running code gen func = ir_pass.infer_type(func) func = ir_pass.fuse_ops(func, cfg.opt_level) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 7bffbd4f499e..ea49a6642796 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -72,9 +72,9 @@ def _mx_conv2d(inputs, attrs): channel_axis = _get_channel_axis(data_layout, "conv2d") if "kernel_layout" in attrs.attrs: - weight_layout = attrs.get_str("kernel_layout") + kernel_layout = attrs.get_str("kernel_layout") else: - weight_layout = "HWIO" if data_layout == "NHWC" else "OIHW" + kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW" new_attrs = {} new_attrs["channels"] = attrs.get_int("num_filter") @@ -84,7 +84,7 @@ def _mx_conv2d(inputs, attrs): new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1)) new_attrs["groups"] = attrs.get_int("num_group", 1) new_attrs["data_layout"] = data_layout - new_attrs["weight_layout"] = weight_layout + new_attrs["kernel_layout"] = kernel_layout use_bias = not attrs.get_bool("no_bias", False) res = _op.nn.conv2d(inputs[0], inputs[1], **new_attrs) if use_bias: @@ -103,9 +103,9 @@ def _mx_conv2d_transpose(inputs, attrs): channel_axis = _get_channel_axis(data_layout, "conv2d_transpose") if "kernel_layout" in attrs.attrs: - weight_layout = attrs.get_str("kernel_layout") + kernel_layout = attrs.get_str("kernel_layout") else: - weight_layout = "HWIO" if data_layout == "NHWC" else "OIHW" + kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW" new_attrs = {} new_attrs["channels"] = attrs.get_int("num_filter") @@ -116,7 +116,7 @@ def _mx_conv2d_transpose(inputs, attrs): new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1)) new_attrs["groups"] = attrs.get_int("num_group", 1) new_attrs["data_layout"] = data_layout - new_attrs["weight_layout"] = weight_layout + new_attrs["kernel_layout"] = kernel_layout use_bias = not attrs.get_bool("no_bias", False) res = 
_op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 8180d8b31044..ad67e78c5ac1 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -55,7 +55,7 @@ def compute_conv2d(attrs, inputs, out_type, target): dilation = get_const_tuple(attrs.dilation) groups = attrs.groups layout = attrs.data_layout - weight_layout = attrs.weight_layout + kernel_layout = attrs.kernel_layout out_dtype = attrs.out_dtype out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "") else out_dtype) @@ -70,13 +70,13 @@ def compute_conv2d(attrs, inputs, out_type, target): inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype) elif layout == "NCHW" and \ - weight_layout == "OIHW" and \ + kernel_layout == "OIHW" and \ get_const_int(inputs[1].shape[0]) == groups and \ get_const_int(inputs[1].shape[1]) == 1: out = topi.nn.depthwise_conv2d_nchw( inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) elif layout == "NHWC" and \ - weight_layout == "HWOI" and\ + kernel_layout == "HWOI" and\ get_const_int(inputs[1].shape[2]) == groups and \ get_const_int(inputs[1].shape[3]) == 1: out = topi.nn.depthwise_conv2d_nhwc( @@ -91,7 +91,7 @@ def schedule_conv2d(attrs, outs, target): """Schedule definition of conv2d""" groups = attrs.groups layout = attrs.data_layout - kernel_layout = attrs.weight_layout + kernel_layout = attrs.kernel_layout with target: if groups == 1 and layout == "NCHW": return topi.generic.schedule_conv2d_nchw(outs) @@ -111,7 +111,8 @@ def schedule_conv2d(attrs, outs, target): @reg.register_alter_op_layout("nn.conv2d") def alter_op_layout_conv2d(attrs, inputs, tinfos): """Alternate the layout of conv2d""" - return None + from ... 
import op + return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -249,7 +250,7 @@ def schedule_l2_normalize(attrs, outs, target): reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE) -# Upsampling +# upsampling reg.register_schedule("nn.upsampling", reg.schedule_injective) def schedule_upsampling(_, outs, target): """Schedule definition of upsampling""" @@ -257,3 +258,50 @@ def schedule_upsampling(_, outs, target): return topi.generic.schedule_injective(outs) # pad reg.register_schedule("nn.pad", schedule_broadcast) + +# winograd related operators +@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform") +def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype, target): + """Compute definition of conv2d_winograd_without_weight_transform""" + # pylint: disable=assignment-from-no-return + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + groups = attrs.get_int("groups") + data_layout = attrs.get_str("data_layout") + out_dtype = attrs.get_str("out_dtype") + tile_size = attrs.get_int("tile_size") + out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype + assert dilation == (1, 1), "Do not support dilate now" + assert groups == 1, "Do not supoort arbitrary group number" + + out = topi.nn.conv2d_winograd_without_weight_transform( + inputs[0], inputs[1], strides, padding, dilation, data_layout, + out_dtype, tile_size) + + return [out] + +@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform") +def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target): + """Schedule definition of conv2d_winograd_without_weight_transform""" + with target: + return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs) + +reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform", + OpPattern.OUT_ELEMWISE_FUSABLE) + + +@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform") +def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target): + """Compute definition of contrib_conv2d_winograd_weight_transform""" + out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size')) + return [out] + +@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform") +def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target): + """Schedule definition of contrib_conv2d_winograd_weight_transform""" + with target: + return topi.generic.schedule_conv2d_winograd_weight_transform(outs) + +reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform", + OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 63b1e206e72c..0acb656c99ac 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -13,7 +13,7 @@ def conv2d(data, channels=None, kernel_size=None, data_layout="NCHW", - weight_layout="OIHW", + kernel_layout="OIHW", out_layout="", out_dtype=""): r"""2D convolution. 
@@ -23,7 +23,7 @@ def conv2d(data, In the default case, where the data_layout is `NCHW` - and weight_layout is `OIHW`, conv2d takes in + and kernel_layout is `OIHW`, conv2d takes in a data Tensor with shape `(batch_size, in_channels, height, width)`, and a weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1])` to produce an output Tensor with the following rule: @@ -70,7 +70,7 @@ def conv2d(data, data_layout : str, optional Layout of the input. - weight_layout : str, optional + kernel_layout : str, optional Layout of the weight. out_layout : str, optional @@ -86,7 +86,7 @@ def conv2d(data, """ return _make.conv2d(data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, - weight_layout, out_layout, out_dtype) + kernel_layout, out_layout, out_dtype) def conv2d_transpose(data, @@ -98,7 +98,7 @@ def conv2d_transpose(data, channels=None, kernel_size=None, data_layout="NCHW", - weight_layout="OIHW", + kernel_layout="OIHW", output_padding=(0, 0), out_dtype=""): """Two dimensional trnasposed convolution operator. @@ -126,7 +126,7 @@ def conv2d_transpose(data, data_layout : str, optional Layout of the input. - weight_layout : str, optional + kernel_layout : str, optional Layout of the weight. output_padding : Tuple[int], optional @@ -142,7 +142,7 @@ def conv2d_transpose(data, """ return _make.conv2d_transpose(data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, - weight_layout, output_padding, out_dtype) + kernel_layout, output_padding, out_dtype) def softmax(data, axis=-1): @@ -765,3 +765,96 @@ def batch_norm(data, center, scale) return TupleWrapper(result, 3) + + +def contrib_conv2d_winograd_without_weight_transform(data, + weight, + tile_size, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="", + out_dtype=""): + r"""2D convolution with winograd algorithm. + + The basic parameters are the same as the ones in vanilla conv2d. + It assumes the weight is pre-transformed by nn.contrib_conv2d_winograd_weight_transform + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + tile_size : int + The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3) + + strides : tuple of int, optional + The strides of convoltution. + + padding : tuple of int, optional + The padding of convolution on both sides of inputs before convolution. + + dilation : tuple of int, optional + Specifies the dilation rate to be used for dilated convolution. + + groups : int, optional + Number of groups for grouped convolution. + + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + + data_layout : str, optional + Layout of the input. + + kernel_layout : str, optional + Layout of the weight. + + out_layout : str, optional + Layout of the output, by default, out_layout is the same as data_layout + + out_dtype : str, optional + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. 
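+
+    Example
+    -------
+    A minimal sketch of the intended usage (names follow this module's API;
+    channels and shapes are illustrative -- a 3x3 kernel transformed with
+    tile_size=4 becomes a (6, 6, O, I) tensor):
+
+        w_t = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=4)
+        y = relay.nn.contrib_conv2d_winograd_without_weight_transform(
+            data, w_t, tile_size=4, channels=64, kernel_size=(3, 3), padding=(1, 1))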
+ """ + return _make.contrib_conv2d_winograd_without_weight_transform( + data, weight, tile_size, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype) + + +def contrib_conv2d_winograd_weight_transform(weight, + tile_size): + r"""Weight Transformation part for 2D convolution with winograd algorithm. + + We separate this as a single op to enable pre-compute for inference. + Use this together with nn.contrib_conv2d_winograd_without_weight_transform + + Parameters + ---------- + weight : tvm.relay.Expr + The weight expressions. + + tile_size : int + The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3) + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 682d56fb9efc..d6d73242bb96 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -5,10 +5,20 @@ @register_relay_attr_node class Conv2DAttrs(Attrs): - """Attribute of a Convolution Operator""" + """Attribute of nn.conv2d""" + pass + +@register_relay_attr_node +class Conv2DWinogradAttrs(Attrs): + """Attribute of nn.contrib_conv2d_winograd_without_weight_transform""" + pass + +@register_relay_attr_node +class Conv2DWinogradWeightTransformAttrs(Attrs): + """Attribute of nn.contrib_conv2d_winograd_weight_transform""" pass @register_relay_attr_node class GlobalPool2DAttrs(Attrs): - """Attribute of a Global 2D Pooling Operator""" + """Attribute of nn.global_pool""" pass diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index f49013928748..aa4c1dbb8742 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -11,4 +11,6 @@ from . import squeezenet from . import vgg from . import densenet + from .config import ctx_list +from .init import create_workload diff --git a/src/relay/op/layout.h b/src/relay/op/layout.h index 90c920bf3aa1..ff19f8d42e57 100644 --- a/src/relay/op/layout.h +++ b/src/relay/op/layout.h @@ -60,7 +60,7 @@ class Layout : public NodeRef { Layout() : Layout("__undef__") {} // NOLINT(*) /*! \brief construct from a string */ - Layout(const char* str) : Layout(std::string(str)) {} // NOLINT(*) + Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) /*! * \brief construct from a string. @@ -70,11 +70,64 @@ class Layout : public NodeRef { * indicates the split dimension. * return undefined layout if "__undef__" is passed. 
*/ - Layout(const std::string& layout) { // NOLINT(*) - if (layout.length() != 0) { - Parse(layout); - } else { - Parse("__undef__"); + Layout(const std::string& name) { // NOLINT(*) + node_ = make_node(); + + std::vector superdim_pos(kUniqueDim, -1); + std::vector subdim_pos(kUniqueDim, -1); + std::vector subdim_size(kUniqueDim, -1); + std::vector layout_simplified; + + if (name != "__undef__") { // parse layout string + int32_t factor = 0; + uint32_t curr = 0; + for (size_t i = 0; i < name.size(); ++i) { + const LayoutDim c = name.at(i); + if (IsSuperdim(c)) { + int pos = c - 'A'; + CHECK_EQ(factor, 0) << "Invalid layout " << name + << ": invalid factor size " << factor + << " before dimension " << c; + CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << name + << ": duplicate dimension " << c; + superdim_pos[pos] = curr++; + layout_simplified.push_back(c); + } else if (IsSubdim(c)) { + int pos = c - 'a'; + CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " + << factor << " for dimension " << c; + CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << name + << ": duplicate dimension " << c; + CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << name + << ": duplicate dimension " << c; + subdim_pos[pos] = curr++; + subdim_size[pos] = factor; + layout_simplified.push_back(c); + factor = 0; + } else if (c >= '0' && c <= '9') { + CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number."; + factor = factor * 10 + c - '0'; + } else { + LOG(FATAL) << "Invalid layout " << name; + } + } + for (LayoutDim dim : layout_simplified) { + CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0) + << "Invalid layout " << name << ": missing axis " + << static_cast(dim - 'a' + 'A'); + } + } + + LayoutNode *node = operator->(); + node->name = name; + + for (uint32_t i = 0; i < kUniqueDim; ++i) { + node->superdim_pos.push_back(superdim_pos[i]); + node->subdim_pos.push_back(subdim_pos[i]); + node->subdim_size.push_back(subdim_size[i]); + } + for (LayoutDim dim : layout_simplified) { + node->layout_simplified.push_back(dim); } } @@ -177,7 +230,6 @@ class Layout : public NodeRef { const Array& layout_simplified = operator->()->layout_simplified; if (pos > ndim()) return Layout::Undef(); if (pos + len > ndim()) len = ndim() - pos; - if (len == 0) return Layout::Undef(); std::ostringstream new_layout; for (size_t i = pos; i < pos + len; ++i) { if (IsSubdim(layout_simplified[i]->value)) { @@ -349,69 +401,6 @@ class Layout : public NodeRef { } using ContainerType = LayoutNode; - - private: - void Parse(const std::string &layout) { - node_ = make_node(); - - std::vector superdim_pos(kUniqueDim, -1); - std::vector subdim_pos(kUniqueDim, -1); - std::vector subdim_size(kUniqueDim, -1); - std::vector layout_simplified; - - if (layout != "__undef__") { // parse layout string - int32_t factor = 0; - uint32_t curr = 0; - for (size_t i = 0; i < layout.size(); ++i) { - const LayoutDim c = layout.at(i); - if (IsSuperdim(c)) { - int pos = c - 'A'; - CHECK_EQ(factor, 0) << "Invalid layout " << layout - << ": invalid factor size " << factor - << " before dimension " << c; - CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; - superdim_pos[pos] = curr++; - layout_simplified.push_back(c); - } else if (IsSubdim(c)) { - int pos = c - 'a'; - CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " - << factor << " for dimension " << c; - CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << layout - << ": 
duplicate dimension " << c; - CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; - subdim_pos[pos] = curr++; - subdim_size[pos] = factor; - layout_simplified.push_back(c); - factor = 0; - } else if (c >= '0' && c <= '9') { - CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number."; - factor = factor * 10 + c - '0'; - } else { - LOG(FATAL) << "Invalid layout " << layout; - } - } - CHECK(!layout_simplified.empty()) << "Invalid layout " << layout; - for (LayoutDim dim : layout_simplified) { - CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0) - << "Invalid layout " << layout << ": missing axis " - << static_cast(dim - 'a' + 'A'); - } - } - - LayoutNode *node = operator->(); - node->name = layout; - - for (uint32_t i = 0; i < kUniqueDim; ++i) { - node->superdim_pos.push_back(superdim_pos[i]); - node->subdim_pos.push_back(subdim_pos[i]); - node->subdim_size.push_back(subdim_size[i]); - } - for (LayoutDim dim : layout_simplified) { - node->layout_simplified.push_back(dim); - } - } }; /*! diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 170b6b6d13c5..608cdab2bacb 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -30,7 +30,7 @@ bool Conv2DRel(const Array& types, const Conv2DAttrs* param = attrs.as(); CHECK(param != nullptr); const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->weight_layout); + const Layout kernel_layout(param->kernel_layout); CHECK(in_layout.Convertible(kNCHW)) << "Conv only support input layouts that are convertible from NCHW." << " But got " << in_layout; @@ -38,8 +38,7 @@ bool Conv2DRel(const Array& types, << "Conv only support kernel layouts that are convertible from OIHW." << " But got "<< kernel_layout; - Layout out_layout(param->out_layout); - if (!out_layout.defined()) out_layout = in_layout; + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); CHECK(out_layout.Convertible(kNCHW)) << "Conv only support output layouts that are convertible from NCHW." << " But got " << out_layout; @@ -110,12 +109,12 @@ Array > Conv2DInferCorrectLayout( const Array& old_in_layouts, const Array> &old_in_shapes) { const T* params = attrs.as(); - Layout out_layout(params->out_layout); // We always make other operators to fit the layouts of convolution layers // So this inference ignores all inputs - return Array >{{params->data_layout, params->weight_layout}, - {out_layout.defined() ? out_layout : params->data_layout}}; + return Array >{{params->data_layout, params->kernel_layout}, + {params->out_layout == "" ? 
+ params->data_layout : params->out_layout}}; } // Positional relay function to create conv2d operator @@ -129,7 +128,7 @@ Expr MakeConv2D(Expr data, IndexExpr channels, Array kernel_size, std::string data_layout, - std::string weight_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { auto attrs = make_node(); @@ -137,10 +136,10 @@ Expr MakeConv2D(Expr data, attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); attrs->groups = groups; - attrs->channels = channels; - attrs->kernel_size = kernel_size; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); attrs->data_layout = std::move(data_layout); - attrs->weight_layout = std::move(weight_layout); + attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.conv2d"); @@ -194,7 +193,7 @@ bool Conv2DTransposeRel(const Array& types, const Conv2DTransposeAttrs* param = attrs.as(); CHECK(param != nullptr); const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->weight_layout); + const Layout kernel_layout(param->kernel_layout); CHECK(in_layout.Convertible(kNCHW)) << "Conv only support input layouts that are convertible from NCHW." << " But got " << in_layout; @@ -202,8 +201,7 @@ bool Conv2DTransposeRel(const Array& types, << "Conv only support kernel layouts that are convertible from OIHW." << " But got "<< kernel_layout; - Layout out_layout(param->out_layout); - if (!out_layout.defined()) out_layout = in_layout; + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); CHECK(out_layout.Convertible(kNCHW)) << "Conv only support output layouts that are convertible from NCHW." 
<< " But got " << out_layout; @@ -279,19 +277,19 @@ Expr MakeConv2DTranspose(Expr data, IndexExpr channels, Array kernel_size, std::string data_layout, - std::string weight_layout, + std::string kernel_layout, Array output_padding, DataType out_dtype) { auto attrs = make_node(); - attrs->channels = channels; - attrs->kernel_size = kernel_size; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->output_padding = std::move(output_padding); attrs->dilation = std::move(dilation); attrs->groups = groups; attrs->data_layout = std::move(data_layout); - attrs->weight_layout = std::move(weight_layout); + attrs->kernel_layout = std::move(kernel_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.conv2d_transpose"); return CallNode::make(op, {data, weight}, Attrs(attrs), {}); @@ -334,5 +332,190 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` Conv2DInferCorrectLayout) .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); + +// relay.nn.contrib_conv2d_winograd_without_weight_transform +TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs); + +bool Conv2DWinogradRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + static const Layout kNCHW("NCHW"); + static const Layout kOIHW("OIHW"); + + const Conv2DWinogradAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + CHECK(in_layout.Convertible(kNCHW)) + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; + CHECK(kernel_layout.Convertible(kOIHW)) + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got "<< kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + CHECK(out_layout.Convertible(kNCHW)) + << "Conv only support output layouts that are convertible from NCHW." + << " But got " << out_layout; + + std::vector dshape_nchw = ConvertLayout( + data->shape, in_layout, kNCHW); + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + + CHECK(param->kernel_size.defined() && param->channels.defined()) + << "The kernel size and channels of a Conv must be set or infered by previous pass"; + + CHECK_EQ(param->kernel_size.size(), 2); + CHECK_EQ(param->dilation.size(), 2); + + channels = param->channels; + dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + + // NOTE: Do not check weight shape here! + // Different backend requires different layout to compute + // the batch gemm stage in winograd efficiently, but we want to + // make this op work for all backends. + // So we accept all weight shapes, and assume the TOPI developers + // can handle this correctly in alter_op_layout. 
+ + // dilation + std::vector oshape({dshape_nchw[0], channels, 0, 0}); + + oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1; + oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1; + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = ConvertLayout(oshape, kNCHW, out_layout); + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + + +// Positional relay function to create conv2d winograd operator +// used by frontend FFI. +Expr MakeConv2DWinograd(Expr data, + Expr weight, + int tile_size, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + auto attrs = make_node(); + attrs->tile_size = tile_size; + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = channels; + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + static const Op& op = Op::Get("nn.contrib_conv2d_winograd_without_weight_transform"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeConv2DWinograd, args, rv); + }); + + +RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") +.describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout. + This operator assumes the weight tensor is already pre-transformed by + nn.contrib_conv2d_winograd_weight_transform. + +- **data**: Input is 4D array of shape (batch_size, in_channels, height, width) +- **weight**: Any shape + We do not check the shape for this input tensor. Since different backend + has different layout strategy. 
+ +- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width) +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.Conv2DWinograd") +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("weight", "Tensor", "The weight tensor.") +.set_support_level(5) +.add_type_rel("Conv2DWinograd", Conv2DWinogradRel) +.set_attr("FInferCorrectLayout", + Conv2DInferCorrectLayout); + +// relay.nn.contrib_conv2d_winograd_weight_transform +TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs); + +bool Conv2DWinogradWeightTransformRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const Conv2DWinogradWeightTransformAttrs* param = attrs.as(); + CHECK(param != nullptr); + + CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; + + // each pad width element should be a pair of positive integers + std::vector oshape { + param->tile_size + data->shape[2] - 1, + param->tile_size + data->shape[3] - 1, + data->shape[0], + data->shape[1], + }; + + reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), + data->dtype)); + return true; +} + +Expr MakeConv2DWinogradWeightTransform(Expr weight, + int tile_size) { + auto attrs = make_node(); + attrs->tile_size = tile_size; + static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform"); + return CallNode::make(op, {weight}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeConv2DWinogradWeightTransform, args, rv); + }); + + +RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform") +.describe(R"code(Weight transformation of winograd fast convolution algorithm. + +Separate this into another nnvm symbol in order to enable Precompute Pass to compute the +weight transformation in advance. + +- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.Conv2DWinogradWeightTransformAttrs") +.set_num_inputs(1) +.add_argument("weight", "Tensor", "The weight tensor.") +.set_support_level(5) +.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 5c4475259086..b33d68a174bc 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -166,7 +166,7 @@ Call CallAlter(const Call& ref_call, } if (!modified) { new_e = CallNode::make(ref_call->op, new_args, - ref_call->attrs, ref_call->type_args); + ref_call->attrs); } const CallNode *new_call = new_e.as(); @@ -184,30 +184,35 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // NOTE: discard the "const" qualifier TransformMemorizer memorizer = Downcast(ctx); - // fill incomplete state and expand tuple - for (auto new_arg : new_args) { - auto push_back_one_arg = [&](Expr arg) { - // We always expect LayoutAlternatedExpr. - // This is used to convert the normal Expr to LayoutAlternatedExpr. 
- if (const LayoutAlternatedExprNode *inp = arg.as()) { - inputs.push_back(GetRef(inp)); - normal_new_args.push_back(inp->value); - } else { - auto inode = make_node(); - inode->value = arg; - inode->memorizer = memorizer; - inputs.push_back(LayoutAlternatedExpr(inode)); - normal_new_args.push_back(arg); - } - }; + // fill incomplete state and flatten tuple + auto push_back_one_arg = [&inputs, memorizer](Expr arg) { + // We always expect LayoutAlternatedExpr. + // This is used to convert the normal Expr to LayoutAlternatedExpr. + if (const LayoutAlternatedExprNode *inp = arg.as()) { + inputs.push_back(GetRef(inp)); + return inp->value; + } else { + auto inode = make_node(); + inode->value = arg; + inode->memorizer = memorizer; + inputs.push_back(LayoutAlternatedExpr(inode)); + return arg; + } + }; + for (auto new_arg : new_args) { + // NOTE: do not support nested tuple if (new_arg->is_type()) { Tuple tuple_new_arg = Downcast(new_arg); + std::vector fields; for (auto x : tuple_new_arg->fields) { - push_back_one_arg(x); + Expr tmp = push_back_one_arg(x); + fields.push_back(tmp); } + normal_new_args.push_back(TupleNode::make(fields)); } else { - push_back_one_arg(new_arg); + Expr tmp = push_back_one_arg(new_arg); + normal_new_args.push_back(tmp); } } @@ -219,7 +224,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, } for (auto arg : ref_call->args) { - if (arg->is_type()) { // expand tuple + if (arg->is_type()) { // flatten tuple Tuple tuple_arg = Downcast(arg); for (auto x : tuple_arg->fields) { input_shapes.push_back(x->type_as()->shape); @@ -263,17 +268,30 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // if (new_in != new_in2): insert transform (new_in -> new_in2) Array transformed_args; - for (size_t i = 0; i < inputs.size(); ++i) { - transformed_args.push_back(memorizer.Transform(new_call->args[i], new_in[i], new_in2[i])); + size_t pt = 0; + for (auto arg : new_call->args) { + if (arg->is_type()) { // unflatten tuple + Tuple tuple_arg = Downcast(arg); + std::vector transformed_tuple_arg; + for (auto arg_item : tuple_arg->fields) { + transformed_tuple_arg.push_back( + memorizer.Transform(arg_item, new_in[pt], new_in2[pt])); + pt++; + } + transformed_args.push_back(TupleNode::make(transformed_tuple_arg)); + } else { + transformed_args.push_back( + memorizer.Transform(arg, new_in[pt], new_in2[pt])); + pt++; + } } + CHECK_EQ(pt, inputs.size()); // state[node] = (old_out, new_out) - CHECK(ref_call->checked_type_.defined()) - << "Call infer_type pass before alter_op_layout pass"; - + // (handle tuple output) if (ref_call->checked_type()->is_type()) { Expr tuple_output = CallNode::make(new_call->op, transformed_args, - new_call->attrs, new_call->type_args); + new_call->attrs); Array fields; for (size_t i = 0; i < new_out.size(); ++i) { auto rnode = make_node(); @@ -288,7 +306,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, auto rnode = make_node(); CHECK_EQ(new_out.size(), 1); rnode->value = CallNode::make(new_call->op, transformed_args, - new_call->attrs, new_call->type_args); + new_call->attrs); rnode->old_layout = old_out[0]; rnode->new_layout = new_out[0]; rnode->memorizer = memorizer; @@ -296,6 +314,9 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, } } +// Limiations: +// 1. the altered op should have the same number of arguments as the previous one +// 2. 
do not support nested tuple arguments TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") .set_body([](TVMArgs args, TVMRetValue *ret) { TransformMemorizer transformMemorizer(make_node()); diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index e346aea518e9..cd2d29e80048 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -91,13 +91,13 @@ class BranchGroupFinder : private ExprVisitor { CHECK(attrs_b); const auto* tweight_a = a->args[1]->type_as(); const auto* tweight_b = b->args[1]->type_as(); - const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW); - const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW); + const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->kernel_layout, kOIHW); + const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->kernel_layout, kOIHW); return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && eq(attrs_a->data_layout, attrs_b->data_layout) && - eq(attrs_a->weight_layout, attrs_b->weight_layout) && + eq(attrs_a->kernel_layout, attrs_b->kernel_layout) && eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) && eq(shape_a[3], shape_b[3]); @@ -159,7 +159,7 @@ class ParallelConv2DCombiner { auto channels = GetConv2DSuperChannelsDim(conv2d); num_filters += channels; } - auto index = branches[0][0]->attrs.as()->weight_layout.find('O'); + auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); CHECK_NE(index, std::string::npos); return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), MakeConstScalar(Int(32), num_filters)); @@ -182,7 +182,7 @@ class ParallelConv2DCombiner { new_attrs->groups = attrs->groups; new_attrs->kernel_size = attrs->kernel_size; new_attrs->data_layout = attrs->data_layout; - new_attrs->weight_layout = attrs->weight_layout; + new_attrs->kernel_layout = attrs->kernel_layout; new_attrs->out_layout = attrs->out_layout; new_attrs->out_dtype = attrs->out_dtype; new_attrs->channels = new_channels; diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 760a226a2fac..60df5d90a87c 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -384,7 +384,7 @@ Array Conv2DForwardPrep(const Call& call, AxesSet out) { const auto* param = call->attrs.as(); CHECK(param != nullptr); Layout data_layout(param->data_layout); - Layout weight_layout(param->weight_layout); + Layout kernel_layout(param->kernel_layout); int c_big_axis = data_layout.Indexof('C'); int c_small_axis = data_layout.Indexof('c'); @@ -397,8 +397,8 @@ Array Conv2DForwardPrep(const Call& call, AxesSet out) { // // only handle depthwise or full conv2d. 
// TODO(tvm-team) handle grouped conv by reshape + bcast - bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); - if (weight_layout.Indexof('i') < 0 && + bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); + if (kernel_layout.Indexof('i') < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { data_axes = {c_big_axis}; @@ -418,19 +418,19 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const auto* param = ref_call->attrs.as(); CHECK(param != nullptr); Layout data_layout(param->data_layout); - Layout weight_layout(param->weight_layout); + Layout kernel_layout(param->kernel_layout); int c_big_axis = data_layout.Indexof('C'); CHECK_GE(c_big_axis, 0); // For now, we only support simple pattern (no folded weight/data) // TODO(tvm-team) support general data layout - CHECK_EQ(weight_layout.Indexof('i'), -1); + CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value); - int big_oc_axis = weight_layout.Indexof('O'); - int big_ic_axis = weight_layout.Indexof('I'); + int big_oc_axis = kernel_layout.Indexof('O'); + int big_ic_axis = kernel_layout.Indexof('I'); // Check it must be depthwise or full conv2d. - bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout); + bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); Expr weight = new_args[1]; @@ -438,11 +438,11 @@ Expr Conv2DForwardRewrite(const Call& ref_call, // match the ic_axis if (is_depthwise_conv2d) { Expr scale = ExpandBiasToMatchAxis( - sdata->scale, weight_layout.ndim(), {big_oc_axis}); + sdata->scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, scale); } else { Expr scale = ExpandBiasToMatchAxis( - sdata->scale, weight_layout.ndim(), {big_ic_axis}); + sdata->scale, kernel_layout.ndim(), {big_ic_axis}); weight = Multiply(weight, scale); } // return transformed conv2d @@ -799,11 +799,8 @@ RELAY_REGISTER_OP("multiply") AxesSet Conv2DBackwardPrep(const Call& call, const Array& in_axes) { const auto* param = call->attrs.as(); CHECK(param != nullptr); - Layout out_layout(param->out_layout); - if (!out_layout.defined()) { - out_layout = Layout(param->data_layout); - } - Layout weight_layout(param->weight_layout); + Layout kernel_layout(param->kernel_layout); + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); int c_big_axis = out_layout.Indexof('C'); int c_small_axis = out_layout.Indexof('c'); @@ -815,9 +812,9 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array& in_axes) { // // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast - bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); - if (weight_layout.Indexof('o') < 0 && - weight_layout.Indexof('i') < 0 && + bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); + if (kernel_layout.Indexof('o') < 0 && + kernel_layout.Indexof('i') < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { return {c_big_axis}; @@ -836,23 +833,20 @@ Expr Conv2DBackwardTransform(const Call& call, } const auto* param = call->attrs.as(); CHECK(param != nullptr); - Layout out_layout(param->out_layout); - if (!out_layout.defined()) { - out_layout = Layout(param->data_layout); - } - Layout weight_layout(param->weight_layout); + Layout kernel_layout(param->kernel_layout); + Layout out_layout(param->out_layout == "" ? 
param->data_layout : param->out_layout); int c_big_axis = out_layout.Indexof('C'); CHECK_GE(c_big_axis, 0); // For now, we only support simple pattern (no folded weight/data) // TODO(tvm-team) support general data layout - CHECK_EQ(weight_layout.Indexof('o'), -1); - CHECK_EQ(weight_layout.Indexof('i'), -1); + CHECK_EQ(kernel_layout.Indexof('o'), -1); + CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK(axes.size() == 1 && c_big_axis == axes[0]->value); - int big_oc_axis = weight_layout.Indexof('O'); + int big_oc_axis = kernel_layout.Indexof('O'); // Check it must be depthwise or full conv2d. - bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); + bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); Expr data = transformer->Transform( @@ -861,7 +855,7 @@ Expr Conv2DBackwardTransform(const Call& call, call->args[1], NullValue(), NullValue()); // scale on input for deptwise. Expr wscale = ExpandBiasToMatchAxis( - scale, weight_layout.ndim(), {big_oc_axis}); + scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, wscale); return CallNode::make( call->op, {data, weight}, call->attrs, call->type_args); diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index e6e8415bd620..24278c2fb236 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -112,11 +112,11 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, */ inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param, - const Layout& weight_layout) { + const Layout& kernel_layout) { static const Layout kOIHW("OIHW"); auto wshape = ConvertLayout( call->args[1]->type_as()->shape, - weight_layout, kOIHW); + kernel_layout, kOIHW); return is_const_int(wshape[0], param->groups) && is_const_int(wshape[1], 1); } @@ -129,7 +129,7 @@ inline bool IsDepthwiseConv2D(const Call& call, inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { auto param = call->attrs.as(); auto tweight = call->args[1]->type_as(); - auto index = param->weight_layout.find('O'); + auto index = param->kernel_layout.find('O'); CHECK_NE(index, std::string::npos); auto channels = as_const_int(tweight->shape[index]); return *channels; diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 0544ee49d159..c8a38565ac7a 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -41,7 +41,7 @@ def test_conv2d_infer_type(): padding=(1, 1), channels=16, data_layout="NCHW4n4c", - weight_layout="OIHW4o4i", + kernel_layout="OIHW4o4i", out_dtype="int32") yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.TensorType( diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 0fa1f1d692d5..48ab2ba271f7 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -91,7 +91,7 @@ def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' - new_attrs['weight_layout'] = 'OIHW16i' + new_attrs['kernel_layout'] = 'OIHW16i' return relay.nn.conv2d(data, weight, **new_attrs) def expected(): @@ -105,7 +105,7 @@ def expected(): channels=64, kernel_size=(3, 3), padding=(1, 1), - weight_layout="OIHW16i", + kernel_layout="OIHW16i", data_layout="NCHW16c") b = relay.expand_dims(bias, axis=1, num_newaxis=2) b = relay.layout_transform(b, "CHW", "CHW16c") @@ -269,7 +269,7 @@ def 
before(): y = relay.Function(free_vars(y), y) return y - @register_alter_op_layout("nn.conv2d", level=102) + @register_alter_op_layout("nn.conv2d", level=105) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) @@ -305,6 +305,107 @@ def expected(): assert(alpha_equal(a, b)) +def test_alter_layout_scalar(): + """Test alternating the layout of a conv2d. + The layout of broadcast operators and the weight should be changed accordingly. + """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.add(y, relay.const(1, "float32")) + y = relay.Function(free_vars(y), y) + return y + + @register_alter_op_layout("nn.conv2d", level=106) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + w = relay.var("weight") + + y = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(y, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + y = relay.add(y, relay.const(1.0, "float32")) + + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.Function(free_vars(y), y) + return y + + a = before() + a = infer_type(a) + a = canonicalize_ops(a) + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + +def test_alter_layout_concatenate(): + """ """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + ret = relay.concatenate([y, y1], axis=1) + y = relay.Function(free_vars(ret), ret) + return y + + @register_alter_op_layout("nn.conv2d", level=107) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NCHW16c') + ret = relay.concatenate([y, y1], axis=1) + ret = relay.layout_transform(ret, "NCHW16c", "NCHW") + y = relay.Function(free_vars(ret), ret) + return y + + a = before() + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) if __name__ == "__main__": test_alter_op() @@ -313,3 +414,5 @@ def expected(): test_alter_layout_dual_path() test_alter_layout_resnet() test_alter_layout_broadcast_op() + test_alter_layout_scalar() + test_alter_layout_concatenate() diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index 57cb7c84b10d..7d0089cfb3c4 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -67,14 +67,14 @@ def before(x, conv_weight, in_bias, in_scale, channels): 
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWIO",
+                             kernel_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         y2 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWIO",
+                             kernel_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         z = relay.add(y1, y2)
@@ -90,7 +90,7 @@ def expected(x, conv_weight, in_bias, in_scale, channels):
                             channels=channels,
                             kernel_size=(3, 3),
                             data_layout="NHWC",
-                            weight_layout="HWIO",
+                            kernel_layout="HWIO",
                             groups=channels,
                             padding=(1, 1))
         y2 = relay.nn.conv2d(x,
@@ -98,7 +98,7 @@ def expected(x, conv_weight, in_bias, in_scale, channels):
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWIO",
+                             kernel_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         z = relay.add(y1, y2)
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index 22c9d2368de3..017c62b77a7b 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -523,9 +523,25 @@ def _callback(op):

 ##### REGISTER ALTER OP LAYOUT #####
 @conv2d_alter_layout.register(["arm_cpu"])
-def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
-    """Alter op layout for pre-computing kernel transformation"""
-    import nnvm.symbol as sym
+def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
+    """Alter op layout for pre-computing kernel transformation
+
+    Parameters
+    ----------
+    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : nnvm.symbol or tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+    F : symbol
+        The context, which can be either nnvm.sym or relay.op
+
+    Note
+    ----
+    Unlike other TOPI functions, this function operates on both graph level and operator level,
+    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+ """ copy_inputs = [s for s in inputs] new_attrs = {k: attrs[k] for k in attrs.keys()} @@ -534,9 +550,11 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): strides = attrs.get_int_tuple("strides") padding = attrs.get_int_tuple("padding") groups = attrs.get_int('groups') - layout = attrs["layout"] + data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout" + layout = attrs[data_layout_key] out_dtype = attrs["out_dtype"] - out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype + if out_dtype == "" or out_dtype == "same": + out_dtype = tinfos[0].dtype if layout != 'NCHW' or groups != 1: return None @@ -570,7 +588,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d) dispatch_ctx.update(target, new_workload, cfg) - return sym.conv2d(*copy_inputs, **new_attrs) + return F.nn.conv2d(*copy_inputs, **new_attrs) else: # pre-compute weight transformation in winograd if "-device=arm_cpu" in target.options: tile_size = 4 @@ -580,10 +598,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): tile_size = _pick_tile_size(tinfos[0], tinfos[1]) VC = cfg['tile_bna'].val - weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size) - weight = sym.reshape(weight, - shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)) - weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3]) + weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size) + weight = F.reshape(weight, + newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)) + weight = F.transpose(weight, axes=[0, 1, 2, 4, 3]) copy_inputs[1] = weight new_attrs['tile_size'] = tile_size @@ -594,8 +612,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_weight, strides, padding, dilation, - new_attrs['layout'], out_dtype, tile_size], + new_attrs[data_layout_key], out_dtype, tile_size], conv2d_winograd_without_weight_transform) dispatch_ctx.update(target, new_workload, cfg) - return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) + return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) diff --git a/topi/python/topi/cuda/conv2d_winograd.py b/topi/python/topi/cuda/conv2d_winograd.py index d32a87ba6b9d..0c2ea3db6f80 100644 --- a/topi/python/topi/cuda/conv2d_winograd.py +++ b/topi/python/topi/cuda/conv2d_winograd.py @@ -330,23 +330,40 @@ def _callback(op): ##### REGISTER ALTER OP LAYOUT ##### @nn.conv2d_alter_layout.register(["cuda", "gpu"]) -def _alter_conv2d_layout(attrs, inputs, tinfos): - """Alter op layout for pre-computing kernel transformation""" +def _alter_conv2d_layout(attrs, inputs, tinfos, F): + """Alter op layout for pre-computing kernel transformation + + Parameters + ---------- + attrs : nnvm.top.AttrDict or tvm.attrs.Attrs + Attributes of current convolution + inputs : nnvm.symbol or tvm.relay.Expr + Grouped input symbols + tinfos : list + Input shape and dtype + F: symbol + The context, can be either nnvm.sym or relay.op + + Note + ---- + Unlike other TOPI functions, this function operates on both graph level and operator level, + so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay. 
+ """ if 'cudnn' in tvm.target.current_target().libs or 'miopen' in tvm.target.current_target().libs: return None - import nnvm.symbol as sym copy_inputs = [s for s in inputs] - new_attrs = {k: attrs[k] for k in attrs.keys()} strides = attrs.get_int_tuple("strides") padding = attrs.get_int_tuple("padding") dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int('groups') - layout = attrs["layout"] + data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout" + layout = attrs[data_layout_key] out_dtype = attrs["out_dtype"] - out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype + if out_dtype == "" or out_dtype == "same": + out_dtype = tinfos[0].dtype data, kernel = tinfos[0:2] N, CI, H, W = get_const_tuple(data.shape) @@ -371,7 +388,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): if cfg.template_key == 'int8': assert 'cuda' in target.keys new_layout = 'NCHW4c' - new_attrs['layout'] = new_layout + new_attrs[data_layout_key] = new_layout new_attrs['out_layout'] = new_layout new_attrs['kernel_layout'] = 'OIHW4o4i' ic_block_factor = oc_block_factor = 4 @@ -386,7 +403,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): conv2d ) dispatch_ctx.update(target, new_workload, cfg) - return sym.conv2d(*copy_inputs, **new_attrs) + return F.nn.conv2d(*copy_inputs, **new_attrs) if attrs.get_int_tuple("dilation") != (1, 1): warnings.warn("Does not support weight pre-transform for dilated convolution.") @@ -395,9 +412,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): # pre-compute weight transformation in winograd tile_size = _infer_tile_size(tinfos[0], tinfos[1]) - weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], - tile_size=tile_size) - weight = sym.transpose(weight, axes=[0, 1, 3, 2]) + weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], + tile_size=tile_size) + weight = F.transpose(weight, axes=[0, 1, 3, 2]) copy_inputs[1] = weight new_attrs['tile_size'] = tile_size @@ -410,7 +427,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): conv2d_winograd_without_weight_transform ) dispatch_ctx.update(target, new_workload, cfg) - return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) + return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) elif groups != CI: workload = autotvm.task.args_to_workload( [tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype], @@ -424,7 +441,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): if cfg.template_key == 'int8': assert 'cuda' in target.keys new_layout = 'NCHW4c' - new_attrs['layout'] = new_layout + new_attrs[data_layout_key] = new_layout new_attrs['out_layout'] = new_layout new_attrs['kernel_layout'] = 'OIHW4o4i' ic_block_factor = oc_block_factor = 4 @@ -440,7 +457,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): group_conv2d_nchw ) dispatch_ctx.update(target, new_workload, cfg) - return sym.conv2d(*copy_inputs, **new_attrs) + return F.nn.conv2d(*copy_inputs, **new_attrs) # do nothing for depthwise convolution return None diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py index d712e71410d7..26679e74990c 100644 --- a/topi/python/topi/intel_graphics/conv2d.py +++ b/topi/python/topi/intel_graphics/conv2d.py @@ -3,6 +3,7 @@ from __future__ import absolute_import as _abs +import warnings import tvm from .. 
 from .. import generic

@@ -37,8 +38,13 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None
     return xi, thread_z, thread_y, thread_x

 @conv2d_alter_layout.register(["intel_graphics"])
-def _alter_conv2d_layout(attrs, inputs, tinfos):
+def _alter_conv2d_layout(attrs, inputs, tinfos, F):
     import nnvm.symbol as sym
+    if F != sym:
+        warnings.warn("Only support alter layout for intel graphics in NNVM now. "
+                      "This pass is ignored in relay.")
+        return None
+
     copy_inputs = [s for s in inputs]

     data = tinfos[0]
diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index d7b1f939ef45..4f2deb12c5cf 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -465,9 +465,9 @@ def _callback(op):

 ##### REGISTER ALTER OP LAYOUT #####
 @conv2d_alter_layout.register(["mali"])
-def _alter_conv2d_layout(attrs, inputs, tinfos):
+def _alter_conv2d_layout(attrs, inputs, tinfos, F):
     try:
-        return _alter_conv2d_layout_arm(attrs, inputs, tinfos)
+        return _alter_conv2d_layout_arm(attrs, inputs, tinfos, F)
     except KeyError:  # to filter out fallback opencl templates
         return None
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index a85d1268dbf8..977b80678524 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -57,17 +57,24 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N


 @tvm.target.generic_func
-def conv2d_alter_layout(attrs, inputs, tinfos):
+def conv2d_alter_layout(attrs, inputs, tinfos, F):
     """Change Conv2D layout.

     Parameters
     ----------
-    attrs : nnvm.top.AttrDict
+    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
         Attributes of current convolution
-    inputs : nnvm.symbol
+    inputs : nnvm.symbol or tvm.relay.Expr
         Grouped input symbols
     tinfos : list
         Input shape and dtype
+    F : symbol
+        The context, which can be either nnvm.sym or relay.op
+
+    Note
+    ----
+    Unlike other TOPI functions, this function operates on both graph level and operator level,
+    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
     """
     # not to change by default
     return None
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index fe38b38d38e0..02735f60e076 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -1,5 +1,7 @@
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
 """Conv2D schedule on x86"""
+import warnings
+
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.topi_integration import deserialize_args
@@ -281,8 +283,13 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):


 @conv2d_alter_layout.register("cpu")
-def _alter_conv2d_layout(attrs, inputs, tinfo):
+def _alter_conv2d_layout(attrs, inputs, tinfo, F):
     import nnvm.symbol as sym
+    if F != sym:
+        warnings.warn("Only support alter layout for x86 in NNVM now. "
+                      "This pass is ignored in relay.")
+        return None
+
     copy_inputs = [s for s in inputs]
     new_attrs = {k : attrs[k] for k in attrs.keys()}
     data, kernel = tinfo[0], tinfo[1]