From 548b39a26024342131e5a78f11a11f1c5676be25 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 Apr 2023 20:15:13 +0900 Subject: [PATCH] use split instead of slice in CombineParallelMatmul --- .../transform/combine_parallel_matmul.cc | 19 +-- .../test_transform_combine_parallel_matmul.py | 116 ++++++++++-------- 2 files changed, 76 insertions(+), 59 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index d6435ec8292f..a7f8711a1f19 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -176,18 +176,21 @@ runtime::TypedPackedFunc(Map)> GetRewriter( } } - PrimExpr begin{0}; - Array strides{1}; + int ind = 0; + Array sections; + for (int i = 0; i < static_cast(indices.size()) - 1; ++i) { + auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1].as(); + ind += width->value; + sections.push_back(IntImm(DataType::Int(64), ind)); + } + int lhs_dim = GetTensorSInfo(inp)->ndim; - int slice_axis = std::max(lhs_dim, rhs_dim) - 1; + int split_axis = std::max(lhs_dim, rhs_dim) - 1; + auto chunks = split(matmul_combined, sections, split_axis); for (size_t i = 0; i < indices.size(); ++i) { - auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1]; auto bound_var = matchings[pattern_to_replace[indices[i]]]; - auto slice = - strided_slice(matmul_combined, {slice_axis}, {begin}, {begin + width}, strides); - replacements.Set(bound_var, slice); - begin += width; + replacements.Set(bound_var, TupleGetItem(chunks, i)); } } diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py index f5cc269620f7..41cba1a58bac 100644 --- a/tests/python/relax/test_transform_combine_parallel_matmul.py +++ b/tests/python/relax/test_transform_combine_parallel_matmul.py @@ -89,10 +89,11 @@ def expected1( with R.dataflow(): lv = R.concat((y, y_1, y_2), axis=1) lv1 = R.matmul(x, lv, out_dtype="float32") - lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) - lv1_1 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) - lv2 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) - lv3 = R.concat((lv_1, lv1_1, lv2), axis=1) + lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=1) + lv_1 = lv2[0] + lv1_1 = lv2[1] + lv2_1 = lv2[2] + lv3 = R.concat((lv_1, lv1_1, lv2_1), axis=1) R.output(lv3) return lv3 @@ -112,10 +113,11 @@ def expected2( with R.dataflow(): lv = R.concat((y, y_1, y_2), axis=1) lv1 = R.matmul(x, lv, out_dtype="float32") - lv_1 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1]) - lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1]) - lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920], strides=[1]) - lv3 = R.concat((lv_1, lv1_1, lv2), axis=1) + lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=2) + lv_1 = lv2[0] + lv1_1 = lv2[1] + lv2_1 = lv2[2] + lv3 = R.concat((lv_1, lv1_1, lv2_1), axis=1) R.output(lv3) return lv3 @@ -141,9 +143,10 @@ def expected1( lv1 = R.matmul(x, lv, out_dtype="float32") lv2 = R.concat((bias, bias_1, bias_2), axis=0) lv3 = R.add(lv1, lv2) - lv1_1 = R.strided_slice(lv3, axes=[1], begin=[0], end=[640], strides=[1]) - lv3_1 = R.strided_slice(lv3, axes=[1], begin=[640], end=[1280], strides=[1]) - lv5 = R.strided_slice(lv3, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv4 = R.split(lv3, indices_or_sections=[640, 1280], axis=1) + lv1_1 = lv4[0] + lv3_1 = lv4[1] + lv5 = lv4[2] lv6 = R.concat((lv1_1, lv3_1, lv5), axis=1) R.output(lv6) return lv6 @@ -165,12 +168,13 @@ def expected2( with R.dataflow(): lv = R.concat((y, y_1, y_2), axis=1) lv1 = R.matmul(x, lv, out_dtype="float32") - lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=1) + lv_1 = lv2[0] lv1_1 = R.add(lv_1, bias) - lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) - lv3 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv2_1 = lv2[1] + lv3 = lv2[2] lv4 = R.add(lv3, bias_1) - lv5 = R.concat((lv1_1, lv2, lv4), axis=1) + lv5 = R.concat((lv1_1, lv2_1, lv4), axis=1) R.output(lv5) return lv5 @@ -192,10 +196,11 @@ def expected1( lv = R.concat((y, y_1, y_2), axis=1) lv1 = R.matmul(x, lv, out_dtype="float32") lv2 = R.nn.relu(lv1) - lv1_1 = R.strided_slice(lv2, axes=[1], begin=[0], end=[640], strides=[1]) - lv3 = R.strided_slice(lv2, axes=[1], begin=[640], end=[1280], strides=[1]) - lv5 = R.strided_slice(lv2, axes=[1], begin=[1280], end=[1920], strides=[1]) - lv6 = R.concat((lv1_1, lv3, lv5), axis=1) + lv3 = R.split(lv2, indices_or_sections=[640, 1280], axis=1) + lv1_1 = lv3[0] + lv3_1 = lv3[1] + lv5 = lv3[2] + lv6 = R.concat((lv1_1, lv3_1, lv5), axis=1) R.output(lv6) return lv6 @@ -214,11 +219,12 @@ def expected2( with R.dataflow(): lv = R.concat((y, y_1, y_2), axis=1) lv1 = R.matmul(x, lv, out_dtype="float32") - lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=1) + lv_1 = lv2[0] lv1_1 = R.nn.gelu(lv_1) - lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) - lv3 = R.nn.relu(lv2) - lv4 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv2_1 = lv2[1] + lv3 = R.nn.relu(lv2_1) + lv4 = lv2[2] lv5 = R.nn.relu(lv4) lv6 = R.concat((lv1_1, lv3, lv5), axis=1) R.output(lv6) @@ -239,11 +245,13 @@ def expected3( with R.dataflow(): lv = R.concat((y, y_1, y_2), axis=1) lv1 = R.matmul(x, lv, out_dtype="float32") - lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=1) + + lv_1 = lv2[0] lv1_1 = R.nn.relu(lv_1) - lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) - lv3 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) - lv4 = R.concat((lv1_1, lv2, lv3), axis=1) + lv2_1 = lv2[1] + lv3 = lv2[2] + lv4 = R.concat((lv1_1, lv2_1, lv3), axis=1) R.output(lv4) return lv4 @@ -270,10 +278,11 @@ def expected1( lv2 = R.concat((bias, bias_1, bias_2), axis=0) lv3 = R.add(lv1, lv2) lv4 = R.nn.relu(lv3) - lv2_1 = R.strided_slice(lv4, axes=[1], begin=[0], end=[640], strides=[1]) - lv5 = R.strided_slice(lv4, axes=[1], begin=[640], end=[1280], strides=[1]) - lv8 = R.strided_slice(lv4, axes=[1], begin=[1280], end=[1920], strides=[1]) - lv9 = R.concat((lv2_1, lv5, lv8), axis=1) + lv5 = R.split(lv4, indices_or_sections=[640, 1280], axis=1) + lv2_1 = lv5[0] + lv5_1 = lv5[1] + lv8 = lv5[2] + lv9 = R.concat((lv2_1, lv5_1, lv8), axis=1) R.output(lv9) return lv9 @@ -297,12 +306,13 @@ def expected2( lv1 = R.matmul(x, lv, out_dtype="float32") lv2 = R.concat((bias, bias_1, bias_2), axis=0) lv3 = R.add(lv1, lv2) - lv1_1 = R.strided_slice(lv3, axes=[1], begin=[0], end=[640], strides=[1]) + lv4 = R.split(lv3, indices_or_sections=[640, 1280], axis=1) + lv1_1 = lv4[0] lv2_1 = R.nn.relu(lv1_1) - lv4 = R.strided_slice(lv3, axes=[1], begin=[640], end=[1280], strides=[1]) - lv6 = R.strided_slice(lv3, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv4_1 = lv4[1] + lv6 = lv4[2] lv7 = R.nn.relu(lv6) - lv8 = R.concat((lv2_1, lv4, lv7), axis=1) + lv8 = R.concat((lv2_1, lv4_1, lv7), axis=1) R.output(lv8) return lv8 @@ -323,14 +333,15 @@ def expected3( with R.dataflow(): lv = R.concat((y, y_1, y_2), axis=1) lv1 = R.matmul(x, lv, out_dtype="float32") - lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=1) + lv_1 = lv2[0] lv1_1 = R.add(lv_1, bias) - lv2 = R.nn.relu(lv1_1) - lv3 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) - lv4 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv2_1 = R.nn.relu(lv1_1) + lv3 = lv2[1] + lv4 = lv2[2] lv5 = R.add(lv4, bias_1) lv6 = R.nn.relu(lv5) - lv7 = R.concat((lv2, lv3, lv6), axis=1) + lv7 = R.concat((lv2_1, lv3, lv6), axis=1) R.output(lv7) return lv7 @@ -370,11 +381,12 @@ def expected1( with R.dataflow(): lv = R.concat((w0, w2), axis=2) lv1 = R.matmul(x, lv, out_dtype="float32") - lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1]) + lv2 = R.split(lv1, indices_or_sections=[640], axis=2) + lv0 = lv2[0] lv1_1 = R.matmul(x, w1, out_dtype="void") - lv2 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1]) + lv2_1 = lv2[1] lv3 = R.matmul(x, w3, out_dtype="void") - out = lv0, lv1_1, lv2, lv3 + out = lv0, lv1_1, lv2_1, lv3 R.output(out) return out @@ -449,16 +461,18 @@ def expected1( with R.dataflow(): lv = R.concat((w0, w1, w2), axis=1) lv1 = R.matmul(x1, lv, out_dtype="float32") - lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1]) - lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1]) + lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=2) + lv0 = lv2[0] + lv1_1 = lv2[1] lv_1 = R.concat((w3, w4), axis=1) lv1_2 = R.matmul(x2, lv_1, out_dtype="float32") - lv2 = R.concat((b0, b1), axis=0) - lv3 = R.add(lv1_2, lv2) - lv5 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640], strides=[1]) - lv2_1 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920], strides=[1]) - lv6 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280], strides=[1]) - out = lv0, lv1_1, lv2_1, lv5, lv6 + lv2_1 = R.concat((b0, b1), axis=0) + lv3 = R.add(lv1_2, lv2_1) + lv4 = R.split(lv3, indices_or_sections=[640], axis=2) + lv5 = lv4[0] + lv2_2 = lv2[2] + lv6 = lv4[1] + out = lv0, lv1_1, lv2_2, lv5, lv6 R.output(out) return out