From 1a8740b58c6ff6d0955f4efd07ed5b176f26f5ed Mon Sep 17 00:00:00 2001
From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com>
Date: Thu, 11 Dec 2025 10:06:31 +0800
Subject: [PATCH] Add layout inference support for repeat operator

---
 src/relax/op/tensor/manipulate.cc            | 60 ++++++++++++-
 .../relax/test_transform_convert_layout.py   | 85 +++++++++++++++++++
 2 files changed, 144 insertions(+), 1 deletion(-)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 0310c7f46b0d..493198fbd091 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1805,12 +1805,70 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
   return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype, data_sinfo->vdevice);
 }
 
-// TODO(relax-team): implement FRelaxInferLayout for repeat
+InferLayoutOutput InferLayoutRepeat(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<RepeatAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  int ndim = tensor_sinfo->ndim;
+
+  // Can't handle sub-indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  // When axis is not specified, the output is 1D (flattened).
+  if (!attrs->axis.has_value()) {
+    return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(1)}, Attrs(call->attrs));
+  }
+
+  // Transform the axis based on the layout.
+  int axis = attrs->axis.value();
+  if (axis < 0) {
+    axis += ndim;
+  }
+
+  // Mark the repeated axis with '1' so it can be traced through the layout permutation.
+  std::string axis_str(ndim, '0');
+  axis_str[axis] = '1';
+  for (int i = 0, j = 0; i < ndim; ++i) {
+    if (axis_str[i] != '1') {
+      axis_str[i] = 'A' + j++;
+    }
+  }
+
+  ffi::String new_axis_str =
+      TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout);
+
+  int64_t new_axis = -1;
+  for (size_t i = 0; i < new_axis_str.size(); ++i) {
+    if (new_axis_str.at(i) == '1') {
+      new_axis = i;
+      break;
+    }
+  }
+  ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";
+
+  ObjectPtr<RepeatAttrs> new_attrs = ffi::make_object<RepeatAttrs>(*attrs);
+  new_attrs->axis = new_axis;
+
+  // When axis is specified, the layout is preserved.
+  return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.repeat")
     .set_attrs_type<RepeatAttrs>()
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRepeat)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutRepeat)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.tile */
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 83b81a6898a7..95f043ef6629 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -4992,5 +4992,90 @@ def main(
     verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]})
 
 
+def test_conv2d_repeat():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 8, 26, 26), "float32") = R.repeat(gv, repeats=2, axis=1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 8), dtype="float32") = R.repeat(gv, repeats=2, axis=3)
+                gv2: R.Tensor((2, 8, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_repeat_flatten():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor((5408,), "float32"):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((5408,), "float32") = R.repeat(gv, repeats=1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor((5408,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((5408,), dtype="float32") = R.repeat(gv, repeats=1)
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()