diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc
index b8875e48ed0f..427254c6e9c9 100644
--- a/src/relay/op/image/resize.cc
+++ b/src/relay/op/image/resize.cc
@@ -33,6 +33,23 @@ namespace relay {
 
 TVM_REGISTER_NODE_TYPE(ResizeAttrs);
 
+// Infer the correct (possibly transformed) layout for image.resize so that
+// ConvertLayout/AlterOpLayout can propagate layouts through it instead of
+// inserting layout_transform ops around the resize.
+Array<Array<Layout>> ResizeInferCorrectLayout(const Attrs& attrs,
+                                              const Array<Layout>& new_in_layouts,
+                                              const Array<Layout>& old_in_layouts,
+                                              const Array<tvm::relay::Type>& old_in_types) {
+  // NOTE: Discard "const" qualifier here.
+  ResizeAttrs* params = const_cast<ResizeAttrs*>(attrs.as<ResizeAttrs>());
+
+  if (new_in_layouts.defined() && new_in_layouts[0].defined()) {
+    // Set the resize with the new layout.
+    ICHECK_EQ(new_in_layouts.size(), 1);
+    params->layout = new_in_layouts[0].name();
+  }
+
+  // resize has one data input and one output; both share the inferred layout.
+  Layout inferred_layout(params->layout);
+  return Array<Array<Layout>>{{inferred_layout}, {inferred_layout}};
+}
+
 bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 2);
@@ -98,6 +115,7 @@ RELAY_REGISTER_OP("image.resize")
     .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(5)
     .add_type_rel("Resize", ResizeRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ResizeInferCorrectLayout)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
 TVM_REGISTER_NODE_TYPE(Resize3dAttrs);
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index ca2469ea0a4c..77885f8d558a 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -244,6 +244,43 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_resize_convert_layout():
+    # ConvertLayout on conv2d (NHWC -> NCHW) should propagate the new layout
+    # through image.resize rather than wrapping it in layout_transforms.
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        bias = relay.var("bias", shape=(64,))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.image.resize(y, size=(56, 56), layout="NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        bias = relay.var("bias", shape=(64,))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.image.resize(y, size=(56, 56), layout="NCHW")
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_conv_concat_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))