From 7790cd4dcc55ab394cc46f2781c52a0737f27116 Mon Sep 17 00:00:00 2001
From: Chris Sullivan
Date: Wed, 7 Apr 2021 11:48:07 -0700
Subject: [PATCH 1/5] [Relay][Pass] Simplify consecutive transpose/layout_transform
 ops when the layout transformation changes rank.

---
 src/relay/transforms/simplify_expr.cc         | 101 ++++++++++++------
 src/relay/transforms/type_infer.cc            |   1 -
 tests/python/relay/test_pass_simplify_expr.py |  18 ++++
 3 files changed, 89 insertions(+), 31 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 5662ef5b45a6..c04ccb41e792 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -118,36 +118,15 @@ class SimplifyTranspose : public DFPatternRewrite {
   Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     // Helper function to get the axes from call node attribute
-    auto get_axes_from_call = [](const Call trans_call, int ndim) {
-      std::vector<int> attr_axes;
-      if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
-        if (attr->axes.defined()) {
-          for (int i = 0; i < ndim; ++i) {
-            int64_t axis = attr->axes[i];
-            axis += (axis < 0) ? ndim : 0;
-            attr_axes.push_back(axis);
-          }
-        } else {
-          // Empty axes means reverse
-          for (int i = ndim - 1; i >= 0; --i) {
-            attr_axes.push_back(i);
-          }
-        }
-      } else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
-        Layout src_layout(attr->src_layout);
-        Layout dst_layout(attr->dst_layout);
-        for (int i = 0; i < ndim; ++i) {
-          attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
-        }
-      } else {
-        CHECK(false) << "Expected transpose or layout_transform, but got "
-                     << Downcast<Op>(trans_call->op)->name;
-      }
-      return std::move(attr_axes);
-    };
 
     auto x = node_map[x_][0];
 
+    Call trans_call = Downcast<Call>(post);
+
+    if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
+      return layout_trans.value();
+    }
+
     // Initialize axes
     int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
     Array<Integer> axes;
@@ -157,10 +136,9 @@ class SimplifyTranspose : public DFPatternRewrite {
 
     // Collect axes changes from the matched pattern, including two consecutive transposes.
     std::vector<std::vector<int>> interm_axes;
-    Call trans_call = Downcast<Call>(post);
-    interm_axes.push_back(get_axes_from_call(trans_call, ndim));
+    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
     trans_call = Downcast<Call>(trans_call->args[0]);
-    interm_axes.push_back(get_axes_from_call(trans_call, ndim));
+    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
 
     // Calculate the final axes in reverse order (from root to output)
     auto it = interm_axes.rbegin();
@@ -190,6 +168,69 @@ class SimplifyTranspose : public DFPatternRewrite {
     return x;
   }
 
+  String PermuteLayout(const String& layout, std::vector<int> axes) const {
+    std::string new_layout{};
+    std::string old_layout{layout};
+    for (auto axis : axes) {
+      new_layout += old_layout[axis];
+    }
+    return String(new_layout);
+  }
+
+  Optional<Expr> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
+    Optional<Expr> layout_trans;
+    if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+      Layout src_layout(attr->src_layout);
+      Layout dst_layout(attr->dst_layout);
+      if (src_layout->axes.size() != dst_layout->axes.size()) {
+        auto axes = GetTransposeAxisOrder(Downcast<Call>(call->args[0]), src_layout->axes.size());
+        std::vector<int> inverse(axes.size());
+        for (size_t i = 0; i < axes.size(); i++) {
+          inverse[axes[i]] = i;
+        }
+        String new_layout = PermuteLayout(attr->src_layout, inverse);
+        layout_trans = MakeLayoutTransform(data, new_layout, dst_layout->name);
+      }
+    } else if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
+      Layout src_layout(attr->src_layout);
+      Layout dst_layout(attr->dst_layout);
+      if (src_layout->axes.size() != dst_layout->axes.size()) {
+        auto axes = GetTransposeAxisOrder(call, dst_layout->axes.size());
+        String new_layout = PermuteLayout(attr->dst_layout, axes);
+        layout_trans = MakeLayoutTransform(data, src_layout->name, new_layout);
+      }
+    }
+    return layout_trans;
+  }
+
+  std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
+    std::vector<int> attr_axes;
+    if (auto attr = call->attrs.as<TransposeAttrs>()) {
+      if (attr->axes.defined()) {
+        for (int i = 0; i < ndim; ++i) {
+          int64_t axis = attr->axes[i];
+          axis += (axis < 0) ? ndim : 0;
+          attr_axes.push_back(axis);
+        }
+      } else {
+        // Empty axes means reverse
+        for (int i = ndim - 1; i >= 0; --i) {
+          attr_axes.push_back(i);
+        }
+      }
+    } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+      Layout src_layout(attr->src_layout);
+      Layout dst_layout(attr->dst_layout);
+      for (int i = 0; i < ndim; ++i) {
+        attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+      }
+    } else {
+      CHECK(false) << "Expected transpose or layout_transform, but got "
+                   << Downcast<Op>(call->op)->name;
+    }
+    return std::move(attr_axes);
+  }
+
  private:
   /*! \brief Pattern input */
   DFPattern x_;

diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc
index 4c6013792426..595a8cf545d4 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -741,7 +741,6 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
 Expr TypeInferencer::Infer(GlobalVar var, Function function) {
   // Set the current function being type checked.
   this->current_func_ = var;
-  // Step 1: Populate the constraints.
   GetType(function);

diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index d1dffa34578b..b9e7207dc9ab 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -106,10 +106,28 @@ def expected3():
         y = relay.transpose(y, axes=[0, 2, 3, 1])
         return relay.Function([x], y)
 
+    # Test a series of transpose and rank changing layout_transform
+    def before4():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.transpose(x, axes=[0, 3, 1, 2])  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
+        y = relay.transpose(y, axes=[0, 2, 3, 1])  # To NHWC
+        return relay.Function([x], y)
+
+    def expected4():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
     for before, expected in [
         [before1(), expected1()],
         [before2(), expected2()],
         [before3(), expected3()],
+        [before4(), expected4()],
     ]:
         after = run_opt_pass(before, transform.SimplifyExpr())
         expected = run_opt_pass(expected, transform.InferType())
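Note on the mechanics of the patch above: GetTransposeAxisOrder maps a rank-preserving layout_transform onto an equivalent transpose (output axis i reads from src_layout.IndexOf(dst_layout[i])), and the callback composes consecutive axis orders into one. A minimal standalone Python sketch of that logic, using hypothetical helper names rather than the TVM API:

    def axis_order(src_layout, dst_layout):
        # Output axis i of the equivalent transpose reads from the position
        # of dst_layout[i] in src_layout.
        return [src_layout.index(c) for c in dst_layout]

    def compose(first, second):
        # Applying `first` then `second` collapses to a single transpose.
        return [first[a] for a in second]

    assert axis_order("NCHW", "NHWC") == [0, 2, 3, 1]
    # NCHW -> NHWC followed by NHWC -> NCHW composes to the identity,
    # so the pair can be elided entirely.
    assert compose(axis_order("NCHW", "NHWC"), axis_order("NHWC", "NCHW")) == [0, 1, 2, 3]
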
From d098f1a4f32039e22b85c0d346969632fb4731be Mon Sep 17 00:00:00 2001
From: Chris Sullivan
Date: Tue, 20 Apr 2021 15:37:20 -0700
Subject: [PATCH 2/5] Refactor to support more combinations of rank changing
 layout transforms.

---
 src/relay/transforms/simplify_expr.cc         | 99 +++++++++++++++----
 src/relay/transforms/type_infer.cc            |  1 +
 tests/python/relay/test_pass_simplify_expr.py | 84 ++++++++++++++++
 3 files changed, 163 insertions(+), 21 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index c04ccb41e792..e45746c8b431 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -33,6 +33,8 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 
+#include <memory>
+
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"
@@ -117,13 +119,17 @@ class SimplifyTranspose : public DFPatternRewrite {
   Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
-    // Helper function to get the axes from call node attribute
-
     auto x = node_map[x_][0];
 
     Call trans_call = Downcast<Call>(post);
 
     if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
+      if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
+        // Prune any trivial layout transformation
+        if (attr->src_layout == attr->dst_layout) {
+          return x;
+        }
+      }
       return layout_trans.value();
     }
 
@@ -177,30 +183,81 @@ class SimplifyTranspose : public DFPatternRewrite {
     return String(new_layout);
   }
 
-  Optional<Expr> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
-    Optional<Expr> layout_trans;
+  struct RankChangingLayoutDescriptor {
+    Layout src_layout;
+    Layout dst_layout;
+    // Either a rank changing layout transform or a transpose
+    Call other_transform;
+  };
+
+  std::unique_ptr<RankChangingLayoutDescriptor> GetRankChangeDescriptor(const Call& call) const {
+    std::unique_ptr<RankChangingLayoutDescriptor> desc{nullptr};
     if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
-      Layout src_layout(attr->src_layout);
-      Layout dst_layout(attr->dst_layout);
-      if (src_layout->axes.size() != dst_layout->axes.size()) {
-        auto axes = GetTransposeAxisOrder(Downcast<Call>(call->args[0]), src_layout->axes.size());
-        std::vector<int> inverse(axes.size());
-        for (size_t i = 0; i < axes.size(); i++) {
-          inverse[axes[i]] = i;
+      if (attr->src_layout.length() != attr->dst_layout.length()) {
+        desc = std::make_unique<RankChangingLayoutDescriptor>();
+        desc->src_layout = Layout(attr->src_layout);
+        desc->dst_layout = Layout(attr->dst_layout);
+        desc->other_transform = Downcast<Call>(call->args[0]);
+      }
+    }
+    if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
+      if (attr->src_layout.length() != attr->dst_layout.length()) {
+        if (!desc) {
+          desc = std::make_unique<RankChangingLayoutDescriptor>();
+          desc->src_layout = Layout(attr->src_layout);
+          desc->dst_layout = Layout(attr->dst_layout);
+          desc->other_transform = call;
+        } else {
+          ICHECK(desc->src_layout->name == attr->dst_layout)
+              << "Back-to-back layout transforms must have the same intermediate layout: "
+              << desc->src_layout->name << " != " << attr->dst_layout;
+          desc->src_layout = Layout(attr->src_layout);
         }
-        String new_layout = PermuteLayout(attr->src_layout, inverse);
-        layout_trans = MakeLayoutTransform(data, new_layout, dst_layout->name);
       }
-    } else if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
-      Layout src_layout(attr->src_layout);
-      Layout dst_layout(attr->dst_layout);
-      if (src_layout->axes.size() != dst_layout->axes.size()) {
-        auto axes = GetTransposeAxisOrder(call, dst_layout->axes.size());
-        String new_layout = PermuteLayout(attr->dst_layout, axes);
-        layout_trans = MakeLayoutTransform(data, src_layout->name, new_layout);
+    }
+    return desc;
+  }
+
+  /*
+   * \brief Fuse call and its argument into a single layout_transform operator
+   * when either call or its argument is a rank changing layout_transform, e.g.,
+   *
+   * Simplify
+   *
+   * [N, H, W, C] -> Transpose -> [N, C, H, W] -> LayoutTrans -> [N, C, H, W, 4c]
+   *
+   * to,
+   *
+   * [N, H, W, C] -> LayoutTrans -> [N, C, H, W, 4c].
+   *
+   * \param data The input expression to the matched pattern
+   * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
+   */
+  Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
+    auto desc = GetRankChangeDescriptor(call);
+    if (desc == nullptr) {
+      return Optional<Call>{nullptr};
+    }
+
+    Optional<Expr> output_layout_trans;
+    if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
+      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
+      std::vector<int> inverse(axes.size());
+      for (size_t i = 0; i < axes.size(); i++) {
+        inverse[axes[i]] = i;
       }
+      String new_layout = PermuteLayout(std::string(desc->src_layout->name), inverse);
+      output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
+    } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
+      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
+      String new_layout = PermuteLayout(std::string(desc->dst_layout->name), axes);
+      output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
+    } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
+      // Fuse two consecutive layout transforms
+      output_layout_trans =
+          MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
     }
-    return layout_trans;
+    return Downcast<Call>(output_layout_trans);
   }
 
   std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {

diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc
index 595a8cf545d4..4c6013792426 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -741,6 +741,7 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
 Expr TypeInferencer::Infer(GlobalVar var, Function function) {
   // Set the current function being type checked.
   this->current_func_ = var;
+  // Step 1: Populate the constraints.
   GetType(function);

diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index b9e7207dc9ab..b7d54c83fc24 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -123,11 +123,95 @@ def expected4():
         y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
         return relay.Function([x], y)
 
+    def before5():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW")  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def expected5():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def before6():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NCHW", "NHWC")
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected6():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before7():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+        y = relay.layout_transform(y, "NCHW8c", "NCHW4c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected7():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before8():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        y = relay.layout_transform(x, "NCHW4c", "NCHW")
+        y = relay.layout_transform(y, "NCHW", "NCHW8c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected8():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def before9():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NCHW", "NCHW4c")  # To NCHW4c
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected9():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before10():
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NCHW", "NHWC")
+        y = relay.layout_transform(y, "NHWC", "CHWN")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected10():
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")  # NCHW
+        y = relay.transpose(x, axes=[1, 2, 3, 0])  # To CHWN
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
     for before, expected in [
         [before1(), expected1()],
         [before2(), expected2()],
         [before3(), expected3()],
         [before4(), expected4()],
+        [before5(), expected5()],
+        [before6(), expected6()],
+        [before7(), expected7()],
+        [before8(), expected8()],
+        [before9(), expected9()],
+        [before10(), expected10()],
     ]:
         after = run_opt_pass(before, transform.SimplifyExpr())
         expected = run_opt_pass(expected, transform.InferType())
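Note on the rank increasing branch of FoldRankChangingLayoutTrans above: the preceding transpose's axis order is inverted and used to rename the source layout, so a transpose plus a rank-expanding layout_transform collapse into a single layout_transform. A standalone sketch of that arithmetic, with hypothetical helpers rather than the TVM API:

    def permute_layout(layout, order):
        return "".join(layout[a] for a in order)

    def fold_transpose_into_layout_transform(transpose_axes, src_layout, dst_layout):
        # Invert the transpose's permutation and apply it to the source layout.
        inverse = [0] * len(transpose_axes)
        for i, a in enumerate(transpose_axes):
            inverse[a] = i
        return permute_layout(src_layout, inverse), dst_layout

    # NHWC --transpose([0, 3, 1, 2])--> NCHW --layout_transform--> NCHW4c
    # folds into a single NHWC -> NCHW4c layout_transform (test case before4/expected4):
    assert fold_transpose_into_layout_transform([0, 3, 1, 2], "NCHW", "NCHW4c") == ("NHWC", "NCHW4c")
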
From 336e4a3a1b8720089515978ed4f68f5ebfe544be Mon Sep 17 00:00:00 2001
From: Chris Sullivan
Date: Tue, 20 Apr 2021 17:06:06 -0700
Subject: [PATCH 3/5] Add comments, logging and validation per CRs.

---
 src/relay/transforms/simplify_expr.cc | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index e45746c8b431..76f9328f5d65 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -123,6 +123,7 @@ class SimplifyTranspose : public DFPatternRewrite {
 
     Call trans_call = Downcast<Call>(post);
 
+    // Try to fuse any rank changing layout transformations
    if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
       if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
         // Prune any trivial layout transformation
@@ -174,13 +175,20 @@ class SimplifyTranspose : public DFPatternRewrite {
     return x;
   }
 
-  String PermuteLayout(const String& layout, std::vector<int> axes) const {
+  String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
     std::string new_layout{};
     std::string old_layout{layout};
-    for (auto axis : axes) {
+    ICHECK_EQ(axes_order.size(), layout.size())
+        << "Number of axes must match the number of named axes in the layout to permute: length("
+        << old_layout << ") != " << axes_order.size();
+    std::stringstream order;
+    for (auto axis : axes_order) {
       new_layout += old_layout[axis];
+      order << axis << ", ";
     }
-    return String(new_layout);
+    DLOG(INFO) << "Using transpose axes order {" << order.str()
+               << "} to permute layout: " << old_layout << " to " << new_layout;
+    return new_layout;
   }
 
   struct RankChangingLayoutDescriptor {
@@ -234,26 +242,33 @@ class SimplifyTranspose : public DFPatternRewrite {
    * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
    */
   Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
+    // Check to see if either the first or second call in matched pattern
+    // is a rank changing layout transform. If so, return a descriptor containing
+    // the layouts and any additional transpose or layout transform op.
     auto desc = GetRankChangeDescriptor(call);
     if (desc == nullptr) {
+      // No rank changing layout transform
       return Optional<Call>{nullptr};
     }
 
     Optional<Expr> output_layout_trans;
+    // Fuse a rank increasing layout transform and a preceding transpose
     if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
       auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
+      // Calculate the reverse axis order and apply to the source layout
       std::vector<int> inverse(axes.size());
       for (size_t i = 0; i < axes.size(); i++) {
         inverse[axes[i]] = i;
       }
-      String new_layout = PermuteLayout(std::string(desc->src_layout->name), inverse);
+      String new_layout = PermuteLayout(desc->src_layout->name, inverse);
       output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
+      // Fuse a rank decreasing layout transform followed by a transpose
     } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
       auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
-      String new_layout = PermuteLayout(std::string(desc->dst_layout->name), axes);
+      String new_layout = PermuteLayout(desc->dst_layout->name, axes);
       output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
+      // Fuse two back-to-back layout transformations which change rank
     } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
-      // Fuse two consecutive layout transforms
       output_layout_trans =
           MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
     }
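The rank decreasing branch works in the forward direction: the trailing transpose's axis order is applied directly to the destination layout, with no inversion needed. Continuing the earlier sketch (same hypothetical permute_layout helper, not the TVM API):

    # NCHW4c --layout_transform--> NCHW --transpose([0, 2, 3, 1])--> NHWC
    # folds into a single NCHW4c -> NHWC layout_transform:
    assert permute_layout("NCHW", [0, 2, 3, 1]) == "NHWC"
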
From fff7311b526ae05ea0ad30420f068a929f6ba29b Mon Sep 17 00:00:00 2001
From: Chris Sullivan
Date: Mon, 3 May 2021 14:42:06 -0700
Subject: [PATCH 4/5] Add test comments.

---
 tests/python/relay/test_pass_simplify_expr.py | 100 ++++++++++++++----
 1 file changed, 82 insertions(+), 18 deletions(-)

diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index b7d54c83fc24..6e2f0fc99239 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -108,12 +108,21 @@ def expected3():
 
     # Test a series of transpose and rank changing layout_transform
     def before4():
-        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
-        y = relay.transpose(x, axes=[0, 3, 1, 2])  # To NCHW
-        y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
+        '''
+        Simplify transpose->layout_transform and its inverse.
+
+        Input:
+        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
+
+        Simplified:
+        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
+        '''
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
+        y = relay.transpose(x, axes=[0, 3, 1, 2])
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")
         y = relay.nn.relu(y)
-        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
-        y = relay.transpose(y, axes=[0, 2, 3, 1])  # To NHWC
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
+        y = relay.transpose(y, axes=[0, 2, 3, 1])
         return relay.Function([x], y)
 
     def expected4():
@@ -124,6 +133,15 @@ def expected4():
         return relay.Function([x], y)
 
     def before5():
+        '''
+        Simplify layout_transform->layout_transform and its inverse.
+
+        Input:
+        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
+
+        Simplified:
+        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
+        '''
         x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
         y = relay.layout_transform(x, "NHWC", "NCHW")  # To NCHW
         y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
@@ -140,64 +158,110 @@ def expected5():
         return relay.Function([x], y)
 
     def before6():
-        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        '''
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW -> NHWC -> NCHW -> op
+
+        Simplified:
+        NCHW -> op
+        '''
+
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
         y = relay.layout_transform(x, "NCHW", "NHWC")
         y = relay.layout_transform(y, "NHWC", "NCHW")
         y = relay.nn.relu(y)
         return relay.Function([x], y)
 
     def expected6():
-        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
         y = relay.nn.relu(x)
         return relay.Function([x], y)
 
     def before7():
-        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        '''
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW4c -> NCHW8c -> NCHW4c -> op
+
+        Simplified:
+        NCHW4c -> op
+        '''
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
         y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
         y = relay.layout_transform(y, "NCHW8c", "NCHW4c")
         y = relay.nn.relu(y)
         return relay.Function([x], y)
 
     def expected7():
-        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
         y = relay.nn.relu(x)
         return relay.Function([x], y)
 
     def before8():
-        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        '''
+        Simplify layout_transform->layout_transform with rank contraction and expansion.
+
+        Input:
+        NCHW4c -> NCHW -> NCHW8c -> op
+
+        Simplified:
+        NCHW4c -> NCHW8c -> op
+        '''
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
         y = relay.layout_transform(x, "NCHW4c", "NCHW")
         y = relay.layout_transform(y, "NCHW", "NCHW8c")
         y = relay.nn.relu(y)
         return relay.Function([x], y)
 
     def expected8():
-        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
         y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
         y = relay.nn.relu(y)
         return relay.Function([x], y)
 
     def before9():
-        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
-        y = relay.layout_transform(x, "NCHW", "NCHW4c")  # To NCHW4c
-        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
+        '''
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW -> NCHW4c -> NCHW -> op
+
+        Simplified:
+        NCHW -> op
+        '''
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.layout_transform(x, "NCHW", "NCHW4c")
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
         y = relay.nn.relu(y)
         return relay.Function([x], y)
 
     def expected9():
-        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
         y = relay.nn.relu(x)
         return relay.Function([x], y)
 
     def before10():
-        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")  # NHWC
+        '''
+        Simplify a rank preserving layout_transform->layout_transform to a transpose.
+
+        Input:
+        NCHW -> NHWC -> CHWN -> op
+
+        Simplified:
+        NCHW -> CHWN -> op
+        '''
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
         y = relay.layout_transform(x, "NCHW", "NHWC")
         y = relay.layout_transform(y, "NHWC", "CHWN")
         y = relay.nn.relu(y)
         return relay.Function([x], y)
 
     def expected10():
-        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")  # NCHW
-        y = relay.transpose(x, axes=[1, 2, 3, 0])  # To CHWN
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.transpose(x, axes=[1, 2, 3, 0])
         y = relay.nn.relu(y)
         return relay.Function([x], y)
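Any of the new cases can be checked by hand with the standard TVM Python API; a minimal sketch mirroring before10/expected10, assuming a build that includes this series:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")  # NCHW
    y = relay.layout_transform(x, "NCHW", "NHWC")
    y = relay.layout_transform(y, "NHWC", "CHWN")
    f = relay.Function([x], relay.nn.relu(y))

    mod = tvm.IRModule.from_expr(f)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.SimplifyExpr()(mod)
    print(mod)  # the two layout_transforms should be folded into one transpose
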
From 667e07c57d2915cbdfd3fc80eede2bd9583a2e27 Mon Sep 17 00:00:00 2001
From: Chris Sullivan
Date: Mon, 3 May 2021 15:40:16 -0700
Subject: [PATCH 5/5] Apply formatting.

---
 src/relay/transforms/simplify_expr.cc         |  4 +--
 tests/python/relay/test_pass_simplify_expr.py | 28 +++++++++----------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 76f9328f5d65..fb7a76f1ea7a 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -31,9 +31,9 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 
-#include <memory>
-
 #include <limits>
+#include <memory>
+#include <utility>
 
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"

diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 6e2f0fc99239..9f11d3827064 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -108,7 +108,7 @@ def expected3():
 
     # Test a series of transpose and rank changing layout_transform
     def before4():
-        '''
+        """
         Simplify transpose->layout_transform and its inverse.
 
         Input:
@@ -116,7 +116,7 @@ def before4():
 
         Simplified:
         NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
-        '''
+        """
         x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
         y = relay.transpose(x, axes=[0, 3, 1, 2])
         y = relay.layout_transform(y, "NCHW", "NCHW4c")
@@ -133,7 +133,7 @@ def expected4():
         return relay.Function([x], y)
 
     def before5():
-        '''
+        """
         Simplify layout_transform->layout_transform and its inverse.
 
         Input:
@@ -141,7 +141,7 @@ def before5():
 
         Simplified:
         NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
-        '''
+        """
         x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
         y = relay.layout_transform(x, "NHWC", "NCHW")  # To NCHW
         y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
@@ -158,7 +158,7 @@ def expected5():
         return relay.Function([x], y)
 
     def before6():
-        '''
+        """
         Remove trivial layout_transform->layout_transform.
 
         Input:
@@ -166,7 +166,7 @@ def before6():
 
         Simplified:
         NCHW -> op
-        '''
+        """
 
         x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
         y = relay.layout_transform(x, "NCHW", "NHWC")
@@ -180,7 +180,7 @@ def expected6():
         return relay.Function([x], y)
 
     def before7():
-        '''
+        """
         Remove trivial layout_transform->layout_transform.
 
         Input:
@@ -188,7 +188,7 @@ def before7():
 
         Simplified:
         NCHW4c -> op
-        '''
+        """
         x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
         y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
         y = relay.layout_transform(y, "NCHW8c", "NCHW4c")
@@ -201,7 +201,7 @@ def expected7():
         return relay.Function([x], y)
 
     def before8():
-        '''
+        """
         Simplify layout_transform->layout_transform with rank contraction and expansion.
 
         Input:
@@ -209,7 +209,7 @@ def before8():
 
         Simplified:
         NCHW4c -> NCHW8c -> op
-        '''
+        """
         x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
         y = relay.layout_transform(x, "NCHW4c", "NCHW")
         y = relay.layout_transform(y, "NCHW", "NCHW8c")
@@ -223,7 +223,7 @@ def expected8():
         return relay.Function([x], y)
 
     def before9():
-        '''
+        """
         Remove trivial layout_transform->layout_transform.
 
         Input:
@@ -231,7 +231,7 @@ def before9():
 
         Simplified:
         NCHW -> op
-        '''
+        """
         x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
         y = relay.layout_transform(x, "NCHW", "NCHW4c")
         y = relay.layout_transform(y, "NCHW4c", "NCHW")
@@ -244,7 +244,7 @@ def expected9():
         return relay.Function([x], y)
 
     def before10():
-        '''
+        """
         Simplify a rank preserving layout_transform->layout_transform to a transpose.
 
         Input:
@@ -252,7 +252,7 @@ def before10():
 
         Simplified:
         NCHW -> CHWN -> op
-        '''
+        """
         x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
         y = relay.layout_transform(x, "NCHW", "NHWC")
         y = relay.layout_transform(y, "NHWC", "CHWN")