
Commit 6185291

[Relay][AlterOp] Improving support for broadcast layout alteration.
1 parent 0cd8047 commit 6185291

7 files changed: +92 -33 lines changed

include/tvm/data_layout.h

Lines changed: 14 additions & 0 deletions
@@ -210,6 +210,20 @@ class Layout : public NodeRef {
     return ct;
   }
 
+  /*! \return Concatenation of all primal axes */
+  inline std::string get_primal_axes() const {
+    std::string primal_axis = "";
+    if (!defined()) {
+      return primal_axis;
+    }
+    for (auto x : operator->()->axes) {
+      if (LayoutAxis::Get(x).IsPrimal()) {
+        primal_axis += LayoutAxis::Get(x).name();
+      }
+    }
+    return primal_axis;
+  }
+
   /*!
    * \brief return the index of the input axis.
    * If it is not found in the layout or the layout is undefined,
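
For intuition: primal axes are the upper-case letters of a layout string, while subordinate axes such as 16c are lower-case and carry a split factor. A minimal standalone Python sketch (illustrative only, not TVM API) of what get_primal_axes() returns:

def primal_axes(layout: str) -> str:
    # Primal axes are the upper-case letters of the layout string;
    # subordinate axes such as "16c" are skipped.
    return "".join(c for c in layout if c.isupper())

print(primal_axes("NCHW16c"))  # "NCHW"
print(primal_axes("C"))        # "C"

So for NCHW16c the result is NCHW, which is what the TransformLayout change below uses to line up broadcast operands.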

src/relay/op/tensor/transform.cc

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
 #include "../op_common.h"
 #include "../../../arithmetic/compute_expr.h"
 #include "../../pass/alter_op_layout.h"
+#include "../../pass/pattern_util.h"
 #include "transform.h"
 
 namespace tvm {

src/relay/pass/alter_op_layout.cc

Lines changed: 35 additions & 12 deletions
@@ -38,26 +38,49 @@
 #include <unordered_map>
 
 #include "alter_op_layout.h"
+#include "pattern_util.h"
 
 namespace tvm {
 namespace relay {
 
 namespace alter_op_layout {
 
 // Make a transform CallNode
+/* Performs 2 operations
+ * 1) If src_layout ndim is smaller than dst_layout, expand_dims is inserted to match the dim size.
+ *    For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC.
+ * 2) Call layout transform with the new src layout.
+ */
 Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
-  if (src_layout.Equals(dst_layout)) { return raw; }
-  CHECK(src_layout.defined() && dst_layout.defined())
-    << "Cannot insert layout transform because there are undefined layouts";
-  CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined())
-    << "Cannot insert layout transform because there are inconvertible layouts: "
-    << src_layout << " v.s. " << dst_layout;
-  static auto &transform_op = Op::Get("layout_transform");
-  NodePtr<LayoutTransformAttrs> attrs = make_node<LayoutTransformAttrs>();
-  attrs->src_layout = src_layout.name();
-  attrs->dst_layout = dst_layout.name();
-  Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs});
-  return std::move(transform);
+  if (src_layout.Equals(dst_layout)) {
+    return raw;
+  }
+
+  // 1) Check if the shape lengths are different. If yes, expand dims.
+  Expr input_expr = raw;
+  Layout new_src_layout = src_layout;
+  if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
+    int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
+    std::string src_primal_layout = src_layout.get_primal_axes();
+    std::string dst_primal_layout = dst_layout.get_primal_axes();
+    std::string new_src_layout_str = "";
+    for (auto s : dst_primal_layout) {
+      if (src_primal_layout.find(s) == std::string::npos) {
+        new_src_layout_str += s;
+      }
+    }
+    new_src_layout_str += src_primal_layout;
+    new_src_layout = Layout(new_src_layout_str);
+    input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
+  }
+
+  // 2) Insert layout transform on the transformed src.
+  CHECK(new_src_layout.defined() && dst_layout.defined())
+    << "Cannot insert layout transform because there are undefined layouts";
+  CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined())
+    << "Cannot insert layout transform because there are inconvertible layouts: "
+    << new_src_layout << " v.s. " << dst_layout;
+  return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
 }
 
 // Memorize layout transform so we can reuse internal transformed nodes
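
To make step 1 concrete, here is a small standalone Python sketch of the same string manipulation (no TVM required; the helper name expand_src_layout is made up for illustration): primal axes missing from the source layout are prepended in destination order, and expand_dims adds one leading dimension per missing axis.

def expand_src_layout(src_primal, dst_primal):
    # Mimics the axis-expansion logic in TransformLayout (illustrative only).
    missing = [a for a in dst_primal if a not in src_primal]
    new_src = "".join(missing) + src_primal  # e.g. "NHW" + "C" -> "NHWC"
    num_new_axis = len(missing)              # dims added by expand_dims at axis 0
    return new_src, num_new_axis

# A per-channel bias broadcast against an NCHW-primal tensor:
print(expand_src_layout("C", "NCHW"))  # ('NHWC', 3)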

src/relay/pass/alter_op_layout.h

Lines changed: 3 additions & 18 deletions
@@ -111,27 +111,12 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
     int scalar = layouts[0].ndim() == 0 ? 0 : 1;
     return Array<Array<Layout> >{layouts, {layouts[1-scalar]}};
   } else {
-    // try to broadcast the tensors to the larger dimension
+    // Set the layout of the larger dimension. If one dimension size is lower, we call expand dims
+    // while transforming layout.
     int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
     int small_idx = 1 - large_idx;
     Layout ret = layouts[large_idx];
-
-    // extract common part
-    size_t i = layouts[large_idx].ndim();
-    for (; i != 0; --i) {
-      const auto& axis = layouts[large_idx][i-1];
-      if (!layouts[small_idx].Contains(axis.ToPrimal())) {
-        break;
-      }
-    }
-
-    Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i);
-    if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) {
-      // not convertible
-      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
-    }
-
-    layouts.Set(small_idx, common_part);
+    layouts.Set(small_idx, ret);
     return Array<Array<Layout> > {layouts, {ret}};
   }
 }
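
A rough Relay-level sketch of the intended effect (mirroring the updated tests below, not a definitive recipe): a (64,)-shaped bias broadcast against an NCHW16c tensor is now given the full NCHW layout and materialized via expand_dims plus layout_transform, instead of a partial CHW16c sub-layout.

from tvm import relay

# Illustrative only: the kind of graph AlterOpLayout is expected to produce
# for a (64,)-shaped bias added to an NCHW16c tensor.
bias = relay.var("bias", shape=(64,), dtype="float32")
b = relay.expand_dims(bias, axis=1, num_newaxis=2)  # (64,)      -> (64, 1, 1)     i.e. CHW
b = relay.expand_dims(b, axis=0, num_newaxis=1)     # (64, 1, 1) -> (1, 64, 1, 1)  i.e. NCHW
b = relay.layout_transform(b, "NCHW", "NCHW16c")    # pack the channel axis into blocks of 16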

src/relay/pass/pattern_util.h

Lines changed: 2 additions & 0 deletions
@@ -505,6 +505,8 @@ Expr MakeSqueeze(Expr data, Array<Integer> axis);
 
 Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
 
+Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout);
+
 Expr StopFusion(Expr data);
 
 Expr CastHint(Expr data, DataType dtype);

tests/python/relay/test_op_qnn_conv2d.py

Lines changed: 31 additions & 0 deletions
@@ -608,6 +608,36 @@ def tflite_anistropic_strides():
     golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
     np.testing.assert_equal(qnn_output, golden_output)
 
+def broadcast_layout_test():
+    # Test broadcast support for NHWC layout.
+    data_shape = (1, 229, 229, 3)  # NHWC
+    data_dtype = 'uint8'
+    kernel_shape = (7, 7, 3, 64)  # HWIO
+    kernel_dtype = 'int8'
+    _, qnn_func = get_funcs(data_shape=data_shape,
+                            data_dtype=data_dtype,
+                            kernel_shape=kernel_shape,
+                            kernel_dtype=kernel_dtype,
+                            input_zero_point=8,
+                            kernel_zero_point=3,
+                            kernel_size=(7, 7),
+                            padding=(1, 1),
+                            strides=(1, 1),
+                            dilation=(1, 1),
+                            data_layout="NHWC",
+                            kernel_layout="HWIO",
+                            out_dtype="int32")
+    func = qnn_func['main'].body
+    bias = relay.var("bias", shape=(64,), dtype="int32")
+
+    # Check broadcast support on both lhs and rhs
+    func = relay.add(func, bias)
+    func = relay.add(bias, func)
+    func = relay.Function(relay.analysis.free_vars(func), func)
+    mod = relay.Module.from_expr(func)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
+
 if __name__ == "__main__":
     no_zero_point_test()
     input_zero_point_test()

@@ -621,3 +651,4 @@ def tflite_anistropic_strides():
     tflite_large_irregular_test()
     tflite_output_multiplier_greater_than_one()
     tflite_anistropic_strides()
+    broadcast_layout_test()

tests/python/relay/test_pass_alter_op_layout.py

Lines changed: 6 additions & 3 deletions
@@ -134,7 +134,8 @@ def expected():
                             kernel_layout="OIHW16i",
                             data_layout="NCHW16c")
         b = relay.expand_dims(bias, axis=1, num_newaxis=2)
-        b = relay.layout_transform(b, "CHW", "CHW16c")
+        b = relay.expand_dims(b, axis=0, num_newaxis=1)
+        b = relay.layout_transform(b, "NCHW", "NCHW16c")
         y = relay.add(y, b)
 
         y = relay.nn.relu(y)

@@ -304,8 +305,10 @@ def expected():
         weight = relay.var("weight")
         x = relay.layout_transform(x, "NCHW", "NCHW16c")
         bias = relay.expand_dims(bias, 1, 2)
-        bias = relay.layout_transform(bias, "CHW", "CHW16c")
-        scale = relay.layout_transform(scale, "CHW", "CHW16c")
+        bias = relay.expand_dims(bias, 0, 1)
+        bias = relay.layout_transform(bias, "NCHW", "NCHW16c")
+        scale = relay.expand_dims(scale, 0, 1)
+        scale = relay.layout_transform(scale, "NCHW", "NCHW16c")
         y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                             data_layout="NCHW16c")
         y = relay.add(y, bias)  # test broadcasting to lhs
