|
38 | 38 | #include <unordered_map> |
39 | 39 |
|
40 | 40 | #include "alter_op_layout.h" |
| 41 | +#include "pattern_util.h" |
41 | 42 |
|
42 | 43 | namespace tvm { |
43 | 44 | namespace relay { |
44 | 45 |
|
45 | 46 | namespace alter_op_layout { |
46 | 47 |
|
47 | 48 | // Make a transform CallNode |
| 49 | +/* Performs 2 operations |
| 50 | + * 1) If src_layout ndim is smaller then dst_layout, expand_dim is inserted to match the dim size. |
| 51 | + * For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC. |
| 52 | + * 2) Call layout transform with new src layout. |
| 53 | + */ |
48 | 54 | Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { |
49 | | - if (src_layout.Equals(dst_layout)) { return raw; } |
50 | | - CHECK(src_layout.defined() && dst_layout.defined()) |
51 | | - << "Cannot insert layout transform because there are undefined layouts"; |
52 | | - CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined()) |
53 | | - << "Cannot insert layout transform because there are inconvertible layouts: " |
54 | | - << src_layout << " v.s. " << dst_layout; |
55 | | - static auto &transform_op = Op::Get("layout_transform"); |
56 | | - NodePtr<LayoutTransformAttrs> attrs = make_node<LayoutTransformAttrs>(); |
57 | | - attrs->src_layout = src_layout.name(); |
58 | | - attrs->dst_layout = dst_layout.name(); |
59 | | - Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs}); |
60 | | - return std::move(transform); |
| 55 | + if (src_layout.Equals(dst_layout)) { |
| 56 | + return raw; |
| 57 | + } |
| 58 | + |
| 59 | + // 1) Check if the shape lengths are different. If yes, expand dims. |
| 60 | + Expr input_expr = raw; |
| 61 | + Layout new_src_layout = src_layout; |
| 62 | + if (src_layout.ndim_primal() < dst_layout.ndim_primal()) { |
| 63 | + int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal(); |
| 64 | + std::string src_primal_layout = src_layout.get_primal_axes(); |
| 65 | + std::string dst_primal_layout = dst_layout.get_primal_axes(); |
| 66 | + std::string new_src_layout_str = ""; |
| 67 | + for (auto s : dst_primal_layout) { |
| 68 | + if (src_primal_layout.find(s) == std::string::npos) { |
| 69 | + new_src_layout_str += s; |
| 70 | + } |
| 71 | + } |
| 72 | + new_src_layout_str += src_primal_layout; |
| 73 | + new_src_layout = Layout(new_src_layout_str); |
| 74 | + input_expr = MakeExpandDims(input_expr, 0, num_new_axis); |
| 75 | + } |
| 76 | + |
| 77 | + // 2) Insert layout transform on the transformed src. |
| 78 | + CHECK(new_src_layout.defined() && dst_layout.defined()) |
| 79 | + << "Cannot insert layout transform because there are undefined layouts"; |
| 80 | + CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined()) |
| 81 | + << "Cannot insert layout transform because there are inconvertible layouts: " |
| 82 | + << new_src_layout << " v.s. " << dst_layout; |
| 83 | + return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name()); |
61 | 84 | } |
62 | 85 |
|
63 | 86 | // Memorize layout transform so we can reuse internal transformed nodes |
|
0 commit comments