diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index e8132df6f75a..baa03f49924d 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -723,6 +723,8 @@ def nms_strategy_cuda(attrs, inputs, out_type, target):
 def roi_align_strategy_cuda(attrs, inputs, out_type, target):
     """roi_align cuda strategy"""
     strategy = _op.OpStrategy()
+    layout = attrs.layout
+    assert layout == "NCHW", "only NCHW layout is supported for now"
     strategy.add_implementation(
         wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
         wrap_topi_schedule(topi.cuda.schedule_roi_align),
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 070efa4375ed..b0c97ecafe6a 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -948,6 +948,8 @@ def _compute_roi_align(attrs, inputs, out_type):
 def roi_align_strategy(attrs, inputs, out_type, target):
     """roi_align generic strategy"""
     strategy = _op.OpStrategy()
+    layout = attrs.layout
+    assert layout == "NCHW", "only NCHW layout is supported for now"
     strategy.add_implementation(
         wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
         wrap_topi_schedule(topi.generic.schedule_roi_align),
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 4d7d7e80b7b7..81b0e8f65a9a 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -380,6 +380,8 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
 def roi_align_strategy_cpu(attrs, inputs, out_type, target):
     """roi_align x86 strategy"""
     strategy = _op.OpStrategy()
+    layout = attrs.layout
+    assert layout == "NCHW", "only NCHW layout is supported for now"
     strategy.add_implementation(
         wrap_compute_roi_align(topi.x86.roi_align_nchw),
         wrap_topi_schedule(topi.generic.schedule_roi_align),
diff --git a/python/tvm/relay/op/vision/_rcnn.py b/python/tvm/relay/op/vision/_rcnn.py
index d20cb97980e7..a5cc266f1566 100644
--- a/python/tvm/relay/op/vision/_rcnn.py
+++ b/python/tvm/relay/op/vision/_rcnn.py
@@ -26,6 +26,49 @@
 reg.register_strategy("vision.roi_align", strategy.roi_align_strategy)
 reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
+@reg.register_convert_op_layout("vision.roi_align")
+def convert_roi_align(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for roi_align op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of the current roi_align
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of layout strings
+        List of layouts defining our desired
+        layout for the data and rois inputs respectively.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+
+    data, rois = inputs
+    new_attrs = dict(attrs)
+    assert (
+        len(desired_layouts) == 2
+    ), "A desired layout is expected for both of vision.roi_align's inputs"
+
+    desired_data_layout, desired_rois_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+    assert desired_rois_layout == "default", "Rois layout must be default"
+
+    new_attrs["layout"] = desired_data_layout
+    # The rois layout does not change.
+    if desired_data_layout in ["NCHW", "NHWC"]:
+        return relay.vision.roi_align(data, rois, **new_attrs)
+
+    raise ValueError("Layout %s is not yet supported." % desired_data_layout)
+
+
 # roi_pool
 @reg.register_compute("vision.roi_pool")
 def compute_roi_pool(attrs, inputs, _):
diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc
index f7e1ecb82dcb..f14b29604f06 100644
--- a/src/relay/op/vision/rcnn_op.cc
+++ b/src/relay/op/vision/rcnn_op.cc
@@ -25,6 +25,8 @@
 #include <tvm/relay/attrs/vision.h>
 #include <tvm/relay/op.h>
 
+#include "../../transforms/infer_layout_util.h"
+
 namespace tvm {
 namespace relay {
 
@@ -43,14 +45,36 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(roi_align_attrs);
   CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
   CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
-  CHECK_EQ(roi_align_attrs->layout, "NCHW") << "ROI Align only supports NCHW layout";
   // assign output type
-  std::vector<IndexExpr> oshape(
-      {rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
+  std::vector<IndexExpr> oshape;
+  if (roi_align_attrs->layout == "NCHW") {
+    oshape = {rshape[0], dshape[1], roi_align_attrs->pooled_size[0],
+              roi_align_attrs->pooled_size[1]};
+  } else {
+    CHECK_EQ(roi_align_attrs->layout, "NHWC") << "Unexpected ROI Align layout";
+    oshape = {rshape[0], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1],
+              dshape[3]};
+  }
+
   reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
+template <typename T>
+Array<Array<Layout> > ROIAlignInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<tvm::relay::Type>& old_in_types) {
+  // NOTE: Discard the "const" qualifier here.
+  T* params = const_cast<T*>(attrs.as<T>());
+  Layout data_layout = params->layout;
+
+  // Layout inference needs to define the layout for all inputs and output data layouts.
+  // For roi_align, the second input is a 2-D tensor with shape [num_roi, 5].
+  // So, we set its layout to "N5".
+  return Array<Array<Layout> >{{data_layout, Layout("N5")}, {data_layout}};
+}
+
 Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
                   int sample_ratio, String layout) {
   auto attrs = make_object<ROIAlignAttrs>();
@@ -78,7 +102,9 @@ RELAY_REGISTER_OP("vision.roi_align")
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("rois", "Tensor", "The input rois")
     .set_support_level(5)
-    .add_type_rel("ROIAlign", ROIAlignRel);
+    .add_type_rel("ROIAlign", ROIAlignRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   ROIAlignInferCorrectLayout<ROIAlignAttrs>);
 
 TVM_REGISTER_NODE_TYPE(ROIPoolAttrs);
 
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 86687eac6b67..9954c0143a68 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -841,6 +841,59 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_roi_align_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var("weight1", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        rois = relay.var("rois", shape=(32, 5))
+        y = relay.vision.roi_align(
+            y, rois, pooled_size=(14, 14), spatial_scale=0.0625, sample_ratio=2, layout="NCHW"
+        )
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var("weight1", shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NHWC")
+        weight1 = relay.layout_transform(weight1, "OIHW", "HWIO")
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        rois = relay.var("rois", shape=(32, 5))
+        y = relay.vision.roi_align(
+            y, rois, pooled_size=(14, 14), spatial_scale=0.0625, sample_ratio=2, layout="NHWC"
+        )
+        ret = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before()
+    desired_layouts = {
+        "nn.conv2d": ["NHWC", "HWIO"],
+        "vision.roi_align": ["NHWC", "default"],
+    }
+    a = run_opt_pass(a, transform.ConvertLayout(desired_layouts))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_default_keyword():
     """ Check that the default keyword selects correct TVM default layout. """
 
@@ -1005,5 +1058,6 @@ def expected():
     test_qnn_conv_nhwc_convert_layout()
     test_conv_convert_kernel_layout()
     test_conv_transpose_convert_layout()
+    test_conv_roi_align_convert_layout()
     test_default_keyword()
     test_different_ops_convert_layout()
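
Usage note (not part of the patch): the sketch below, distilled from test_conv_roi_align_convert_layout above, shows how the new registration is exercised end to end. The graph shapes and attribute values simply mirror the test; the variable names are illustrative only.

import tvm
from tvm import relay

# Build a small NCHW graph: conv2d feeding roi_align.
x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("w", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                    data_layout="NCHW", kernel_layout="OIHW")
rois = relay.var("rois", shape=(32, 5))
y = relay.vision.roi_align(y, rois, pooled_size=(14, 14),
                           spatial_scale=0.0625, sample_ratio=2, layout="NCHW")
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

# Request NHWC for the data input; rois must stay "default" (see convert_roi_align).
desired_layouts = {
    "nn.conv2d": ["NHWC", "HWIO"],
    "vision.roi_align": ["NHWC", "default"],
}
mod = relay.transform.ConvertLayout(desired_layouts)(mod)
print(mod)  # roi_align now carries layout="NHWC", with layout_transforms inserted around it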