From c3bae77e7329ec71ce03eb613517f5fe2eb4b589 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Sat, 13 Mar 2021 00:54:31 +0000 Subject: [PATCH 1/5] [Relay][Pass] Simplify consecutive transpose/layout_transform --- src/relay/op/make_op.h | 2 + src/relay/transforms/simplify_expr.cc | 92 +++++++++++++++++++ tests/python/relay/test_pass_simplify_expr.py | 34 +++++++ 3 files changed, 128 insertions(+) diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 79f7e135e29d..36a5ec1c0e72 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -75,6 +75,8 @@ Expr MakeSqueeze(Expr data, Array axis); Expr MakeStack(Expr data, int axis); +Expr MakeTranspose(Expr data, Array axes); + Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides, String slice_mode); diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 74e48dc4bc54..da651a7ced1d 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -82,6 +82,97 @@ class SimplifyReshape : public SimplifyPattern { DFPattern x_; }; +/*! + * \brief SimplifyTranspose matches the pattern of consecutive transpose op, + * and merges or cancels them. + */ +class SimplifyTranspose : public SimplifyPattern { + public: + SimplifyTranspose() { + x_ = IsWildcard(); + auto trans1 = IsOp("transpose") || IsOp("layout_transform"); + auto trans2 = IsOp("transpose") || IsOp("layout_transform"); + pattern_ = trans1({trans2({x_})}); + } + + Expr callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + // Helper function to get the axes from call node attribute + auto get_axes_from_call = [](const Call trans_call, size_t ndim) { + std::vector attr_axes; + if (auto attr = trans_call->attrs.as()) { + if (attr->axes.defined()) { + for (size_t i = 0; i < ndim; ++i) { + attr_axes.push_back(attr->axes[i]); + } + } else { + // Empty axes means reverse + for (size_t i = ndim - 1; i >= 0; --i) { + attr_axes.push_back(i); + } + } + } else if (auto attr = trans_call->attrs.as()) { + Layout src_layout(attr->src_layout); + Layout dst_layout(attr->dst_layout); + for (size_t i = 0; i < ndim; ++i) { + attr_axes.push_back(src_layout.IndexOf(dst_layout[i])); + } + } else { + CHECK(false) << "Expected transpose or layout_transform, but got " + << Downcast(trans_call->op)->name; + } + return std::move(attr_axes); + }; + + auto x = node_map[x_][0]; + + // Initialize axes + auto ndim = Downcast(pre->checked_type())->shape.size(); + Array axes; + for (size_t i = 0; i < ndim; ++i) { + axes.push_back(i); + } + + // Collect axes changes from the matched pattern, including two consecutive transposes. + std::vector> interm_axes; + Call trans_call = Downcast(post); + interm_axes.push_back(get_axes_from_call(trans_call, ndim)); + trans_call = Downcast(trans_call->args[0]); + interm_axes.push_back(get_axes_from_call(trans_call, ndim)); + + // Calculate the final axes in reverse order (from root to output) + auto it = interm_axes.rbegin(); + while (it != interm_axes.rend()) { + auto interm = *it; + + Array new_axes; + for (size_t i = 0; i < ndim; ++i) { + new_axes.push_back(axes[interm[i]]); + } + axes = new_axes; + it++; + } + + // Check if the transpose is still required + bool need_transpose = false; + for (int i = 0; i < static_cast(ndim); ++i) { + if (axes[i] != i) { + need_transpose = true; + break; + } + } + + if (need_transpose) { + return MakeTranspose(x, axes); + } + return x; + } + + private: + /*! \brief Pattern input */ + DFPattern x_; +}; + /*! * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op */ @@ -162,6 +253,7 @@ class ExprSimplifier { public: explicit ExprSimplifier(IRModule mod) : mod_(mod) { CreateCallback(SimplifyReshape()); + CreateCallback(SimplifyTranspose()); CreateCallback(FullElementwise()); } template diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 9531d896b2ed..329d74f1e535 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -17,6 +17,7 @@ import tvm from tvm import relay from tvm.relay import transform +from tvm.relay.op.transform import transpose from tvm.relay.testing import run_opt_pass import numpy as np @@ -60,6 +61,38 @@ def symbolic(): assert tvm.ir.structural_equal(zz, after) +def test_simplify_transpose(): + def before1(): + x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW + y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC + y = relay.layout_transform(y, "NHWC", "HWCN") # To HWCN + y = relay.transpose(y, axes=[3, 0, 1, 2]) # To NHWC + return relay.Function([x], y) + + def expected1(): + x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW + y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC + return relay.Function([x], y) + + def before2(): + x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW + y = relay.nn.relu(x) + y = relay.transpose(y, axes=[0, 2, 3, 1]) # To NHWC + y = relay.transpose(y, axes=[1, 2, 3, 0]) # To HWCN + y = relay.transpose(y, axes=[3, 2, 0, 1]) # To NCHW + return relay.Function([x], y) + + def expected2(): + x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW + y = relay.nn.relu(x) + return relay.Function([x], y) + + for before, expected in [[before1(), expected1()], [before2(), expected2()]]: + after = run_opt_pass(before, transform.SimplifyExpr()) + expected = run_opt_pass(expected, transform.InferType()) + assert tvm.ir.structural_equal(after, expected) + + def test_simplify_full_elementwise(): def validate(shape, value, dtype): def before_left(x, elem_op, full): @@ -126,4 +159,5 @@ def after_right(x, elem_op, value): if __name__ == "__main__": test_simplify_reshape() + test_simplify_transpose() test_simplify_full_elementwise() From 78b1230b1fec9e82cdc5e6062c840fec57ec4568 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Sat, 13 Mar 2021 00:55:41 +0000 Subject: [PATCH 2/5] lint --- tests/python/relay/test_pass_simplify_expr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 329d74f1e535..00f280488391 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -17,7 +17,6 @@ import tvm from tvm import relay from tvm.relay import transform -from tvm.relay.op.transform import transpose from tvm.relay.testing import run_opt_pass import numpy as np From a39e1c9aebc26757ecd684efd7b2ed13319e609c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Sat, 13 Mar 2021 01:02:50 +0000 Subject: [PATCH 3/5] fix --- src/relay/transforms/simplify_expr.cc | 16 ++++++++-------- tests/python/relay/test_pass_simplify_expr.py | 18 +++++++++++++++++- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index da651a7ced1d..03f20e5fe39c 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -98,23 +98,23 @@ class SimplifyTranspose : public SimplifyPattern { Expr callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { // Helper function to get the axes from call node attribute - auto get_axes_from_call = [](const Call trans_call, size_t ndim) { + auto get_axes_from_call = [](const Call trans_call, int ndim) { std::vector attr_axes; if (auto attr = trans_call->attrs.as()) { if (attr->axes.defined()) { - for (size_t i = 0; i < ndim; ++i) { + for (int i = 0; i < ndim; ++i) { attr_axes.push_back(attr->axes[i]); } } else { // Empty axes means reverse - for (size_t i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0; --i) { attr_axes.push_back(i); } } } else if (auto attr = trans_call->attrs.as()) { Layout src_layout(attr->src_layout); Layout dst_layout(attr->dst_layout); - for (size_t i = 0; i < ndim; ++i) { + for (int i = 0; i < ndim; ++i) { attr_axes.push_back(src_layout.IndexOf(dst_layout[i])); } } else { @@ -127,9 +127,9 @@ class SimplifyTranspose : public SimplifyPattern { auto x = node_map[x_][0]; // Initialize axes - auto ndim = Downcast(pre->checked_type())->shape.size(); + int ndim = Downcast(pre->checked_type())->shape.size(); Array axes; - for (size_t i = 0; i < ndim; ++i) { + for (int i = 0; i < ndim; ++i) { axes.push_back(i); } @@ -146,7 +146,7 @@ class SimplifyTranspose : public SimplifyPattern { auto interm = *it; Array new_axes; - for (size_t i = 0; i < ndim; ++i) { + for (int i = 0; i < ndim; ++i) { new_axes.push_back(axes[interm[i]]); } axes = new_axes; @@ -155,7 +155,7 @@ class SimplifyTranspose : public SimplifyPattern { // Check if the transpose is still required bool need_transpose = false; - for (int i = 0; i < static_cast(ndim); ++i) { + for (int i = 0; i < ndim; ++i) { if (axes[i] != i) { need_transpose = true; break; diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 00f280488391..9556730a6512 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -86,7 +86,23 @@ def expected2(): y = relay.nn.relu(x) return relay.Function([x], y) - for before, expected in [[before1(), expected1()], [before2(), expected2()]]: + def before3(): + x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW + y = relay.nn.relu(x) + y = relay.transpose(y) # Reverse + y = relay.transpose(y) # Reverse + return relay.Function([x], y) + + def expected3(): + x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW + y = relay.nn.relu(x) + return relay.Function([x], y) + + for before, expected in [ + [before1(), expected1()], + [before2(), expected2()], + [before3(), expected3()], + ]: after = run_opt_pass(before, transform.SimplifyExpr()) expected = run_opt_pass(expected, transform.InferType()) assert tvm.ir.structural_equal(after, expected) From 0b57a1cd392522fb7144d181b12f2d941279e11c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 15 Mar 2021 17:32:42 +0000 Subject: [PATCH 4/5] support negative --- src/relay/transforms/simplify_expr.cc | 4 +++- tests/python/relay/test_pass_simplify_expr.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 03f20e5fe39c..3c8876ceccb5 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -103,7 +103,9 @@ class SimplifyTranspose : public SimplifyPattern { if (auto attr = trans_call->attrs.as()) { if (attr->axes.defined()) { for (int i = 0; i < ndim; ++i) { - attr_axes.push_back(attr->axes[i]); + int64_t axis = attr->axes[i]; + axis += (axis < 0) ? ndim : 0; + attr_axes.push_back(axis); } } else { // Empty axes means reverse diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 9556730a6512..91cf6d7da980 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -61,6 +61,7 @@ def symbolic(): def test_simplify_transpose(): + # Test a series of transpose and layout_transform ops def before1(): x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC @@ -73,6 +74,7 @@ def expected1(): y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC return relay.Function([x], y) + # Test that all transpose ops can be cancelled def before2(): x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW y = relay.nn.relu(x) @@ -86,16 +88,21 @@ def expected2(): y = relay.nn.relu(x) return relay.Function([x], y) + # Test default axis (reverse) and negative axis def before3(): x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW y = relay.nn.relu(x) y = relay.transpose(y) # Reverse y = relay.transpose(y) # Reverse + y = relay.transpose(y, axes=[0, 2, -1, 1]) + y = relay.transpose(y) # Reverse + y = relay.transpose(y) # Reverse return relay.Function([x], y) def expected3(): x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW y = relay.nn.relu(x) + y = relay.transpose(y, axes=[0, 2, 3, 1]) return relay.Function([x], y) for before, expected in [ From 9f877c2758ae53517d609381565ae76250b6ba3f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 15 Mar 2021 17:33:23 +0000 Subject: [PATCH 5/5] comment --- tests/python/relay/test_pass_simplify_expr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 91cf6d7da980..897f90b9ee2a 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -112,7 +112,9 @@ def expected3(): ]: after = run_opt_pass(before, transform.SimplifyExpr()) expected = run_opt_pass(expected, transform.InferType()) - assert tvm.ir.structural_equal(after, expected) + assert tvm.ir.structural_equal(after, expected), "\nafter: {} \nexpected: {}".format( + after, expected + ) def test_simplify_full_elementwise():