diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 4fa8aca4f3a9..b92feb59e4dc 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -131,13 +131,11 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
   uint32_t indim = old_in_shapes[0].size();
   auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
 
-  Layout ret = Layout::Undef();
-  if (new_in_layouts.defined() && r_axes.size()) {
-    // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
-    // modified layout axes.
-    ICHECK_EQ(new_in_layouts.size(), 1);
-    ICHECK_EQ(old_in_layouts.size(), 1);
-
+  Layout inferred_in = Layout::Undef();
+  Layout inferred_out = Layout::Undef();
+
+  // Infer [in_layout, out_layout, new_r_axes] from old_in_layout or new_in_layout
+  auto infer = [&](const Layout& layout) {
     // 1) Collect the original axes
     std::unordered_set<std::string> old_r_dims;
     for (auto r_axis : r_axes) {
@@ -146,9 +144,10 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
 
     // 2) Collect the new axes by walking new_layout.
     tvm::Array<tvm::Integer> new_r_axes;
-    std::string new_layout_string = "";
+    std::string inferred_in_string = "";
+    std::string inferred_out_string = "";
     int axis_index = 0;
-    for (auto iter_var : new_in_layouts[0]->axes) {
+    for (auto iter_var : layout->axes) {
       const auto& layout_axis = LayoutAxis::Get(iter_var);
       const std::string& layout_dim = layout_axis.name();
       if (old_r_dims.count(layout_dim)) {
@@ -156,21 +155,40 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
       }
       // Collect only the primal axis.
       if (layout_axis.IsPrimal()) {
-        new_layout_string += layout_dim;
+        if (!old_r_dims.count(layout_dim) || params->keepdims) {
+          inferred_out_string += layout_dim;
+        }
+        inferred_in_string += layout_dim;
         axis_index++;
       }
     }
 
     // 3) Set the new axis and layout.
-    ret = Layout(new_layout_string);
+    return std::make_tuple(Layout(inferred_in_string), Layout(inferred_out_string), new_r_axes);
+  };
+
+  std::string new_layout_string;
+  Array<Integer> new_r_axes;
+
+  if (new_in_layouts.defined() && r_axes.size()) {
+    // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
+    // modified layout axes.
+    ICHECK_EQ(new_in_layouts.size(), 1);
+    ICHECK_EQ(old_in_layouts.size(), 1);
+
+    // Get inferred_in and inferred_out from new_in_layout.
+    std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
     params->axis = new_r_axes;
   } else if (old_in_layouts.defined()) {
-    // If the new layout is undefined, set the old layout as the inferred layout.
     ICHECK_EQ(old_in_layouts.size(), 1);
-    ret = old_in_layouts[0];
+
+    // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
+    if (old_in_layouts[0].defined()) {
+      std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
+    }
   }
 
-  return Array<Array<Layout>>{{ret}, {ret}};
+  return Array<Array<Layout>>{{inferred_in}, {inferred_out}};
 }
 
 template <typename F>
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 6fb9f77f99ea..11e94cb4b93e 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2159,6 +2159,69 @@ Array<te::Tensor> SqueezeCompute(const Attrs& attrs, const Array<te::Tensor>& in
   return {topi::squeeze(inputs[0], param->axis)};
 }
 
+Array<Array<Layout>> SqueezeInferCorrectLayout(const Attrs& attrs,
+                                               const Array<Layout>& new_in_layouts,
+                                               const Array<Layout>& old_in_layouts,
+                                               const Array<tvm::relay::Type>& old_in_types) {
+  // NOTE: Discard "const" qualifier here.
+  SqueezeAttrs* params = const_cast<SqueezeAttrs*>(attrs.as<SqueezeAttrs>());
+
+  Layout inferred_input = new_in_layouts.defined() ? new_in_layouts[0] : old_in_layouts[0];
+  Layout inferred_output = inferred_input;
+
+  ICHECK(old_in_types[0].as<TensorTypeNode>());
+  const auto& shape = old_in_types[0].as<TensorTypeNode>()->shape;
+
+  // axis to squeeze
+  Array<Integer> axis;
+  if (params->axis.defined()) {
+    axis = params->axis;
+  } else {
+    // if axes is None, squeeze all axes of dimension 1
+    for (size_t i = 0; i < shape.size(); i++) {
+      if (topi::detail::GetConstInt(shape[i]) == 1) {
+        axis.push_back(i);
+      }
+    }
+  }
+
+  // If new_in_layouts are defined, this code tries to modify the layout
+  if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    Array<Integer> new_axis;
+    for (const auto& e : axis) {
+      const auto& dim = old_in_layouts[0][e];
+      new_axis.push_back((new_in_layouts[0]).IndexOf(dim));
+    }
+    params->axis = new_axis;
+    axis = new_axis;
+  }
+
+  // Infer output layout
+  Array<tir::IterVar> kept_axes;
+  for (size_t i = 0; i < inferred_input.ndim(); i++) {
+    bool is_dim_kept = true;
+
+    // Check whether the dim should be kept
+    for (const auto& e : axis) {
+      int64_t axis_val = e->value;
+      if (axis_val < 0) {
+        axis_val += inferred_input.ndim();
+      }
+      if (static_cast<int64_t>(i) == axis_val) {
+        is_dim_kept = false;
+        break;
+      }
+    }
+
+    if (is_dim_kept) {
+      kept_axes.push_back(inferred_input->axes[i]);
+    }
+  }
+  inferred_output = Layout(kept_axes);
+
+  return Array<Array<Layout>>{{inferred_input}, {inferred_output}};
+}
+
 RELAY_REGISTER_OP("squeeze")
     .describe(R"code(Squeeze the input tensor at the dimensions given by axes
 
@@ -2171,7 +2234,8 @@ RELAY_REGISTER_OP("squeeze")
     .set_support_level(3)
     .add_type_rel("Squeeze", SqueezeRel)
     .set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
-    .set_attr<TOpPattern>("TOpPattern", kInjective);
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout);
 
 // CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
 bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc
index ba443f602c19..1293be70273e 100644
--- a/src/relay/transforms/convert_layout.cc
+++ b/src/relay/transforms/convert_layout.cc
@@ -91,9 +91,17 @@ class ConvertTransformMemorizer : public TransformMemorizer {
       auto desired_layouts = operator->()->desired_layouts_;
       if (desired_layouts.find(op->name) != desired_layouts.end()) {
         tvm::Array<tvm::te::Tensor> tinfos;
-        for (auto expr : ref_call->args) {
-          auto ttype = expr->type_as<TensorTypeNode>();
-          tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
+        for (auto& expr : ref_call->args) {
+          if (expr->checked_type()->IsInstance<TupleTypeNode>()) {
+            auto tuple_ttype_node = expr->type_as<TupleTypeNode>();
+            for (auto& ttype : tuple_ttype_node->fields) {
+              auto ttype_node = ttype.as<TensorTypeNode>();
+              tinfos.push_back(tvm::te::placeholder(ttype_node->shape, ttype_node->dtype));
+            }
+          } else {
+            auto ttype = expr->type_as<TensorTypeNode>();
+            tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
+          }
         }
 
         Array<String> op_desired_layouts = desired_layouts.at(op->name);
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index ca2469ea0a4c..7eccc4a82c70 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1556,6 +1556,191 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_squeeze_convert_layout():
+    def _test_conv_squeeze_convert_layout1():
+        # specified axis is squeezed
+        def before():
+            x = relay.var("x", shape=(1, 1, 1, 2048))
+            weight = relay.var("weight", shape=(1, 1, 2048, 1000))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=1000,
+                kernel_size=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = relay.nn.relu(y)
+            y = relay.squeeze(y, axis=[-3])
+            return relay.Function(analysis.free_vars(y), y)
+
+        def expected():
+            x = relay.var("x", shape=(1, 1, 1, 2048))
+            weight = relay.var("weight", shape=(1, 1, 2048, 1000))
+            weight = relay.layout_transform(weight, "HWIO", "OIHW")
+            x = relay.layout_transform(x, "NHWC", "NCHW")
+            y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
+            y = relay.nn.relu(y)
+            y = relay.squeeze(y, axis=[2])
+            y = relay.layout_transform(y, "NCW", "NWC")
+            return relay.Function(analysis.free_vars(y), y)
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+    def _test_conv_squeeze_convert_layout2():
+        # all axes of dimension 1 are squeezed
+        def before():
+            x = relay.var("x", shape=(1, 1, 1, 2048))
+            weight = relay.var("weight", shape=(1, 1, 2048, 1000))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=1000,
+                kernel_size=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = relay.nn.relu(y)
+            y = relay.squeeze(y)
+            return relay.Function(analysis.free_vars(y), y)
+
+        def expected():
+            x = relay.var("x", shape=(1, 1, 1, 2048))
+            weight = relay.var("weight", shape=(1, 1, 2048, 1000))
+            weight = relay.layout_transform(weight, "HWIO", "OIHW")
+            x = relay.layout_transform(x, "NHWC", "NCHW")
+            y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
+            y = relay.nn.relu(y)
+            y = relay.squeeze(y, [0, 2, 3])
+            return relay.Function(analysis.free_vars(y), y)
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+    def _test_conv_squeeze_convert_layout3():
+        # squeeze axis is empty
+        def before():
+            x = relay.var("x", shape=(1, 1, 1, 2048))
+            weight = relay.var("weight", shape=(1, 1, 2048, 1000))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=1000,
+                kernel_size=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = relay.nn.relu(y)
+            y = relay.squeeze(y, axis=[])
+            return relay.Function(analysis.free_vars(y), y)
+
+        def expected():
+            x = relay.var("x", shape=(1, 1, 1, 2048))
+            weight = relay.var("weight", shape=(1, 1, 2048, 1000))
+            weight = relay.layout_transform(weight, "HWIO", "OIHW")
+            x = relay.layout_transform(x, "NHWC", "NCHW")
+            y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
+            y = relay.nn.relu(y)
+            y = relay.squeeze(y, axis=[])
+            y = relay.layout_transform(y, "NCHW", "NHWC")
+            return relay.Function(analysis.free_vars(y), y)
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+    _test_conv_squeeze_convert_layout1()
+    _test_conv_squeeze_convert_layout2()
+    _test_conv_squeeze_convert_layout3()
+
+
+def test_conv_reduce_convert_layout():
+    def _test_conv_reduce_convert_layout1():
+        def before():
+            x = relay.var("x", shape=(1, 1, 1, 2048))
+            weight = relay.var("weight", shape=(1, 1, 2048, 1000))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=1000,
+                kernel_size=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = relay.nn.relu(y)
+            y = relay.sum(y, axis=(1, 2))
+            y = relay.sum(y, axis=(1,))
+            y = relay.sum(y)
+            y = relay.sum(y)
+            return relay.Function(analysis.free_vars(y), y)
+
+        def expected():
+            x = relay.var("x", shape=(1, 1, 1, 2048))
+            weight = relay.var("weight", shape=(1, 1, 2048, 1000))
+            weight = relay.layout_transform(weight, "HWIO", "OIHW")
+            x = relay.layout_transform(x, "NHWC", "NCHW")
+            y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
+            y = relay.nn.relu(y)
+            y = relay.sum(y, axis=(2, 3))
+            y = relay.sum(y, axis=(1,))
+            y = relay.sum(y)
+            y = relay.sum(y)
+            return relay.Function(analysis.free_vars(y), y)
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+    def _test_conv_reduce_convert_layout2():
+        def before():
+            x = relay.var("x", shape=(1, 38, 38, 512))
+            weight = relay.var("weight", shape=(3, 3, 512, 512))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=512,
+                kernel_size=(3, 3),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = relay.nn.relu(y)
+            y = relay.multiply(y, y)
+            y = relay.sum(y, axis=(3,), keepdims=True)
+            return relay.Function(analysis.free_vars(y), y)
+
+        def expected():
+            x = relay.var("x", shape=(1, 38, 38, 512))
+            weight = relay.var("weight", shape=(3, 3, 512, 512))
+            weight = relay.layout_transform(weight, "HWIO", "OIHW")
+            x = relay.layout_transform(x, "NHWC", "NCHW")
+            y = relay.nn.conv2d(x, weight, channels=512, kernel_size=(3, 3))
+            y = relay.nn.relu(y)
+            y = relay.multiply(y, y)
+            y = relay.sum(y, axis=(1,), keepdims=True)
+            y = relay.layout_transform(y, "NCHW", "NHWC")
+            return relay.Function(analysis.free_vars(y), y)
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+    _test_conv_reduce_convert_layout1()
+    _test_conv_reduce_convert_layout2()
+
+
 if __name__ == "__main__":
     test_qnn_binary_no_convert_layout()
     test_no_convert_layout()
@@ -1584,3 +1769,5 @@ def expected():
     test_different_ops_convert_layout()
     test_no_desired_layout()
     test_convert_with_config()
+    test_conv_squeeze_convert_layout()
+    test_conv_reduce_convert_layout()
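
As an end-to-end illustration (not part of the patch itself), the short standalone script below runs ConvertLayout over an NHWC conv2d + relu + squeeze graph, which is the path the new SqueezeInferCorrectLayout hook handles. It is a minimal sketch that assumes a TVM build containing this change; the pass sequence mirrors the run_opt_pass helper already used in test_pass_convert_op_layout.py, and with the patch applied the printed module should keep squeeze on the NCHW axes and end with a single NCW -> NWC layout_transform, as in _test_conv_squeeze_convert_layout1 above.

import tvm
from tvm import relay
from tvm.relay import analysis, transform


def main():
    # NHWC convolution followed by relu and a squeeze of the size-1 height axis.
    x = relay.var("x", shape=(1, 1, 1, 2048))
    weight = relay.var("weight", shape=(1, 1, 2048, 1000))
    y = relay.nn.conv2d(
        x,
        weight,
        channels=1000,
        kernel_size=(1, 1),
        data_layout="NHWC",
        kernel_layout="HWIO",
    )
    y = relay.nn.relu(y)
    y = relay.squeeze(y, axis=[-3])
    func = relay.Function(analysis.free_vars(y), y)

    # Ask ConvertLayout to rewrite conv2d to NCHW and let the other ops adapt.
    mod = tvm.IRModule.from_expr(func)
    seq = tvm.transform.Sequential(
        [
            transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}),
            transform.InferType(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    print(mod["main"])


if __name__ == "__main__":
    main()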