diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index c9f14c91c7b1..5001925b7570 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -149,23 +149,41 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
     tvm::Array<tvm::Integer> new_r_axes;
     std::string inferred_in_string = "";
     std::string inferred_out_string = "";
-    int axis_index = 0;
-    for (auto iter_var : layout->axes) {
-      const auto& layout_axis = LayoutAxis::Get(iter_var);
+    auto push_new_axis = [&](const std::string& layout_dim, int axis) {
+      if ((old_r_dims.count(layout_dim) && !params->exclude) ||
+          (!old_r_dims.count(layout_dim) && params->exclude)) {
+        new_r_axes.push_back(tvm::Integer(axis));
+        return true;
+      }
+      return false;
+    };
+    for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) {
+      const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]);
       const std::string& layout_dim = layout_axis.name();
-      // Collect only the primal axis.
       if (layout_axis.IsPrimal()) {
-        if (old_r_dims.count(layout_dim) && !params->exclude) {
-          new_r_axes.push_back(tvm::Integer(axis_index));
-        }
-        if (!old_r_dims.count(layout_dim) && params->exclude) {
-          new_r_axes.push_back(tvm::Integer(axis_index));
-        }
+        push_new_axis(layout_dim, axis_index);
+        inferred_in_string += layout_dim;
         if (!old_r_dims.count(layout_dim) || params->keepdims) {
           inferred_out_string += layout_dim;
         }
-        inferred_in_string += layout_dim;
-        axis_index++;
+      } else {
+        // For example, if the original layout is NCHW, the new layout is NCHW8c, and the original
+        // reduce axes are [1], the new reduce axes become [1, 4].
+        auto primal_dim = layout_axis.ToPrimal().name();
+        auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim;
+        inferred_in_string += packed_dim;
+        if (push_new_axis(primal_dim, axis_index)) {
+          if (params->exclude) {
+            // The primal axis is not reduced, so keep the input packed dim.
+            inferred_out_string += packed_dim;
+          } else {
+            // If the primal axis is part of the reduce axes in the original layout, the inner dim
+            // becomes 1 after reduction.
+            inferred_out_string += "1" + layout_dim;
+          }
+        } else {
+          inferred_out_string += packed_dim;
+        }
       }
     }
 
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 19685b127d86..ab36f79c6ea7 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -507,12 +507,12 @@ def expected():
         bias = relay.layout_transform(bias, src_layout="NHWC", dst_layout="NCHW")
         bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
         add = relay.add(y, bias)
-        y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
-        mean = relay.mean(y, axis=1, exclude=True)
-        var = relay.variance(y, axis=1, exclude=True)
+        mean = relay.mean(add, axis=[1, 4], exclude=True)
+        var = relay.variance(add, axis=[1, 4], exclude=True)
         denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
         gamma = relay.var("gamma", shape=(16,))
-        denom = denom * gamma
+        denom_c16c = denom * relay.layout_transform(gamma, src_layout="C", dst_layout="C16c")
+        denom = relay.layout_transform(denom_c16c, src_layout="C16c", dst_layout="C")
         denom_expand1 = relay.expand_dims(denom, axis=1, num_newaxis=2)
         denom_expand2 = relay.expand_dims(denom_expand1, axis=0)
         denom_nchwc16 = relay.layout_transform(
@@ -520,7 +520,10 @@
         )
         out = add * denom_nchwc16
         beta = relay.var("beta", shape=(16,))
-        numerator = (-mean) * denom + beta
+        numerator_c16c = (-mean) * denom_c16c + relay.layout_transform(
+            beta, src_layout="C", dst_layout="C16c"
+        )
+        numerator = relay.layout_transform(numerator_c16c, src_layout="C16c", dst_layout="C")
         numerator_expand1 = relay.expand_dims(numerator, axis=1, num_newaxis=2)
         numerator_expand2 = relay.expand_dims(numerator_expand1, axis=0)
         numerator_nchwc16 = relay.layout_transform(
@@ -1096,8 +1099,8 @@ def expected_nchw():
         y = relay.nn.conv2d(
             y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
         )
-        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
-        ret = relay.sum(ret, axis=[1], keepdims=True)
+        ret = relay.sum(y, axis=[1, 4], keepdims=True)
+        ret = relay.layout_transform(ret, "NCHW1c", "NCHW")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
@@ -1126,9 +1129,8 @@ def expected_nhwc():
         y = relay.nn.conv2d(
             y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
         )
-        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
-        ret = relay.sum(ret, axis=[1], keepdims=True)
-        ret = relay.layout_transform(ret, "NCHW", "NHWC")
+        ret = relay.sum(y, axis=[1, 4], keepdims=True)
+        ret = relay.layout_transform(ret, "NCHW1c", "NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
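
For reference, a minimal sketch (not part of the patch, and assuming a TVM build that includes this change) of the axis remapping the new ReduceInferCorrectLayout rule produces: a channel reduction over an NCHW tensor (axis=[1]) maps, under an NCHW16c layout, to both the primal and packed channel axes (axis=[1, 4]); the packed inner dimension collapses to 1, so the result is restored with an "NCHW1c" -> "NCHW" layout_transform, as in expected_nchw() above. The variable names and shapes below are illustrative only.

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 2, 7, 7, 16))            # data already in NCHW16c (C = 2 * 16)
    red = relay.sum(x, axis=[1, 4], keepdims=True)        # reduce primal C and packed 16c together
    out = relay.layout_transform(red, "NCHW1c", "NCHW")   # inner dim is now 1; drop it
    print(tvm.IRModule.from_expr(relay.Function([x], out)))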