From 296a050a16c2f06b0f4a1edf8f6d59fd95c99e86 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 3 Aug 2018 18:22:30 +0900
Subject: [PATCH 01/11] enhanced op fusion for elu

---
 nnvm/src/compiler/graph_fuse.cc | 36 +++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index d4e668972593..6c56ad7d398c 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -161,6 +161,42 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
       }
     }
   }
+
+  if (opt_level >= 1) {
+    std::vector<std::vector<uint32_t>> children_group_ids(idx.num_nodes());
+    std::vector<std::vector<uint32_t>> node_ids_per_group(idx.num_nodes());
+    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
+      const auto& inode = idx[nid];
+      if (inode.source->is_variable()) continue;
+      CHECK_NE(group_vec[nid], -1);
+      node_ids_per_group[group_vec[nid]].push_back(nid);
+      if (inode.inputs.size() != 1) continue;
+      const auto& parent_nid = inode.inputs[0].node_id;
+      // if parent node has more than one child, record each child's group id.
+      if (ref_count[parent_nid] > 1) children_group_ids[parent_nid].push_back(group_vec[nid]);
+    }
+    std::vector<int> new_group_id(idx.num_nodes(), -1);
+    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
+      if (new_group_id[group_vec[nid]] != -1) {
+        // propagate new group id from child
+        group_vec[nid] = new_group_id[group_vec[nid]];
+      }
+      const auto& group_ids = children_group_ids[nid];
+      if (group_ids.size() <= 1) continue;
+      const auto child_group_id = group_ids[0];
+      // fuse this node with children if all children belong to the same group
+      auto is_same_group_id = [child_group_id](uint32_t id) { return id == child_group_id; };
+      if (std::all_of(group_ids.begin(), group_ids.end(), is_same_group_id)) {
+        new_group_id[group_vec[nid]] = child_group_id;
+        group_vec[nid] = child_group_id;
+        for (uint32_t nid2 : node_ids_per_group[child_group_id]) {
+          pattern_vec[nid2] = pattern_vec[nid];
+          master_vec[nid2] = master_vec[nid];
+        }
+      }
+    }
+  }
+
   g.attrs["group_root"] = std::make_shared<any>(std::move(group_vec));
   g.attrs["group_master"] = std::make_shared<any>(std::move(master_vec));
   g.attrs["pattern"] = std::make_shared<any>(std::move(pattern_vec));

From dfbfd46fe1d9d94f746e638d25788c37457df490 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 3 Aug 2018 20:05:27 +0900
Subject: [PATCH 02/11] don't fuse when op is opaque type

---
 nnvm/src/compiler/graph_fuse.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index 6c56ad7d398c..39142cf14fd0 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -181,6 +181,8 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
         // propagate new group id from child
         group_vec[nid] = new_group_id[group_vec[nid]];
       }
+      TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
+      if (pt == kOpaque) continue;
       const auto& group_ids = children_group_ids[nid];
       if (group_ids.size() <= 1) continue;
      const auto child_group_id = group_ids[0];
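The rule these two patches implement can be modeled in a few lines of plain Python. The sketch below is hypothetical and for illustration only; the real pass walks nnvm's IndexedGraph and reads the registered TOpPattern attribute, and the kOpaque value is assumed here to equal 5.

```python
# Minimal Python model of the fusion rule in patches 01-02 (hypothetical sketch;
# the real code operates on nnvm's IndexedGraph). `group` maps node id -> group id,
# `children_groups` maps a multi-child parent -> the group ids of its children,
# `pattern` maps node id -> TOpPattern value.
K_OPAQUE = 5  # assumed to mirror nnvm's kOpaque

def fuse_parents_into_children(group, children_groups, pattern):
    new_group = {}                            # old group id -> merged group id
    for nid in sorted(group, reverse=True):   # node ids are topologically ordered
        if group[nid] in new_group:
            # a node below was fused already; inherit the merged group id
            group[nid] = new_group[group[nid]]
        if pattern.get(nid, K_OPAQUE) == K_OPAQUE:
            continue                          # opaque ops are never fused
        gids = children_groups.get(nid, [])
        if len(gids) > 1 and len(set(gids)) == 1:
            # all child branches landed in one group: pull the parent in
            new_group[group[nid]] = gids[0]
            group[nid] = gids[0]
    return group
```

For a conv2d followed by ELU, every branch of the diamond records the same child group id, so the conv2d node inherits that group and the whole subgraph compiles to one kernel.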
From b8662d9921b6739dd5c008ad44e2ec81c3726368 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 3 Aug 2018 21:37:47 +0900
Subject: [PATCH 03/11] add elu test case

---
 nnvm/tests/python/compiler/test_op_fusion.py | 50 ++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py
index f33e18197840..2683e6223540 100644
--- a/nnvm/tests/python/compiler/test_op_fusion.py
+++ b/nnvm/tests/python/compiler/test_op_fusion.py
@@ -77,7 +77,57 @@ def test_injective_reduce_injective():
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
+def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
+    with nnvm.compiler.build_config(opt_level=opt_level):
+        graph, lib, params = nnvm.compiler.build(sym, target, shape={"data":data.shape}, params=params)
+    module = graph_runtime.create(graph, lib, ctx)
+    module.set_input(**params)
+    module.set_input("data", data)
+    module.run()
+    out = module.get_output(0, tvm.nd.empty(out_shape))
+    return out.asnumpy(), graph
+
+
+def test_fuse_conv2d_elu():
+    def elu(data):
+        return -0.5 * sym.relu(1 - sym.exp(data)) + sym.relu(data)
+
+    def get_sym(out_channel):
+        data = sym.Variable(name="data")
+        data = sym.conv2d(data=data, kernel_size=(3,3), channels=out_channel, padding=(1, 1),
+                          layout="NCHW", kernel_layout="OIHW", use_bias=True)
+        data = elu(data)
+        return data
+
+    in_channel = 8
+    out_channel = 16
+    size = 64
+    dshape = (1, in_channel, size, size)
+    oshape = (1, out_channel, size, size)
+
+    conv_weight = np.random.uniform(-1, 1, (out_channel, in_channel, 3, 3)).astype(np.float32)
+    conv_bias = np.random.uniform(-1, 1, (out_channel)).astype(np.float32)
+    params = {
+        "conv2d0_weight" : tvm.nd.array(conv_weight, ctx=tvm.cpu(0)),
+        "conv2d0_bias" : tvm.nd.array(conv_bias, ctx=tvm.cpu(0))
+    }
+
+    params2 = {
+        "conv2d1_weight" : tvm.nd.array(conv_weight.copy(), ctx=tvm.cpu(0)),
+        "conv2d1_bias" : tvm.nd.array(conv_bias.copy(), ctx=tvm.cpu(0))
+    }
+
+    data = np.random.uniform(-1, 1, dshape).astype(np.float32)
+    sym1 = get_sym(out_channel)
+    sym2 = get_sym(out_channel)
+
+    for target, ctx in ctx_list():
+        output, g1 = build_and_run(sym1, params, data, oshape, target, ctx, opt_level=2)
+        output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
+        np.testing.assert_allclose(output, output2, rtol=1e-5, atol=1e-5)
+
 if __name__ == "__main__":
     test_injective_reduce_injective()
     test_ewise_injective()
     test_conv_ewise_injective()
+    test_fuse_conv2d_elu()
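The `elu` helper above is the exponential linear unit with alpha = 0.5, rewritten in terms of primitives nnvm already fuses. The identity is easy to check numerically with NumPy (a standalone sketch, not part of the patch):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

def elu_reference(x, alpha=0.5):
    # ELU(x) = x for x > 0, alpha * (exp(x) - 1) otherwise
    return np.where(x > 0, x, alpha * (np.exp(x) - 1))

x = np.random.uniform(-3, 3, 1000).astype(np.float32)
decomposed = -0.5 * relu(1 - np.exp(x)) + relu(x)
np.testing.assert_allclose(decomposed, elu_reference(x), rtol=1e-5, atol=1e-6)
```

For x > 0 the relu(1 - exp(x)) term vanishes and only relu(x) = x remains; for x <= 0 it reduces to 0.5 * (exp(x) - 1).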
From c658bccd47721916226719c59fe22abd0a084aea Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sat, 4 Aug 2018 08:52:54 +0900
Subject: [PATCH 04/11] update test

---
 nnvm/tests/python/compiler/test_op_fusion.py | 26 ++++++++------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py
index 2683e6223540..22bd53708797 100644
--- a/nnvm/tests/python/compiler/test_op_fusion.py
+++ b/nnvm/tests/python/compiler/test_op_fusion.py
@@ -5,7 +5,7 @@
 from tvm.contrib import graph_runtime
 from nnvm import symbol as sym
 from nnvm.compiler import graph_util, graph_attr
-from nnvm.testing import ctx_list
+from nnvm.testing import ctx_list, utils
 
 def test_ewise_injective():
     x = sym.Variable("x")
@@ -96,6 +96,7 @@ def get_sym(out_channel):
         data = sym.Variable(name="data")
         data = sym.conv2d(data=data, kernel_size=(3,3), channels=out_channel, padding=(1, 1),
                           layout="NCHW", kernel_layout="OIHW", use_bias=True)
+        data = sym.batch_norm(data)
         data = elu(data)
         return data
 
@@ -104,25 +105,18 @@ def get_sym(out_channel):
     in_channel = 8
     out_channel = 16
     size = 64
     dshape = (1, in_channel, size, size)
     oshape = (1, out_channel, size, size)
-
-    conv_weight = np.random.uniform(-1, 1, (out_channel, in_channel, 3, 3)).astype(np.float32)
-    conv_bias = np.random.uniform(-1, 1, (out_channel)).astype(np.float32)
-    params = {
-        "conv2d0_weight" : tvm.nd.array(conv_weight, ctx=tvm.cpu(0)),
-        "conv2d0_bias" : tvm.nd.array(conv_bias, ctx=tvm.cpu(0))
-    }
-
-    params2 = {
-        "conv2d1_weight" : tvm.nd.array(conv_weight.copy(), ctx=tvm.cpu(0)),
-        "conv2d1_bias" : tvm.nd.array(conv_bias.copy(), ctx=tvm.cpu(0))
-    }
-
-    data = np.random.uniform(-1, 1, dshape).astype(np.float32)
     sym1 = get_sym(out_channel)
     sym2 = get_sym(out_channel)
+    _, params1 = utils.create_workload(sym1, 1, dshape[1:])
+    _, params2 = utils.create_workload(sym2, 1, dshape[1:])
+    for (p1, p2) in zip(params1.values(), params2.values()):
+        p2.copyfrom(p1)
 
+    data = np.random.uniform(-1, 1, dshape).astype(np.float32)
     for target, ctx in ctx_list():
-        output, g1 = build_and_run(sym1, params, data, oshape, target, ctx, opt_level=2)
+        print("Running on target", target)
+        output, g1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
         output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
         np.testing.assert_allclose(output, output2, rtol=1e-5, atol=1e-5)

From 4fe634f820a2c08028cfb2d9b44e092d6754db64 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sat, 4 Aug 2018 09:30:08 +0900
Subject: [PATCH 05/11] update test

---
 nnvm/tests/python/compiler/test_op_fusion.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py
index 22bd53708797..35af41abe1cd 100644
--- a/nnvm/tests/python/compiler/test_op_fusion.py
+++ b/nnvm/tests/python/compiler/test_op_fusion.py
@@ -105,16 +105,15 @@ def get_sym(out_channel):
     size = 64
     dshape = (1, in_channel, size, size)
     oshape = (1, out_channel, size, size)
-    sym1 = get_sym(out_channel)
-    sym2 = get_sym(out_channel)
-    _, params1 = utils.create_workload(sym1, 1, dshape[1:])
-    _, params2 = utils.create_workload(sym2, 1, dshape[1:])
-    for (p1, p2) in zip(params1.values(), params2.values()):
-        p2.copyfrom(p1)
-
     data = np.random.uniform(-1, 1, dshape).astype(np.float32)
     for target, ctx in ctx_list():
+        sym1 = get_sym(out_channel)
+        sym2 = get_sym(out_channel)
+        _, params1 = utils.create_workload(sym1, 1, dshape[1:], seed=0)
+        _, params2 = utils.create_workload(sym2, 1, dshape[1:], seed=0)
+        print(params1.keys())
+        print(params2.keys())
         print("Running on target", target)
         output, g1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
         output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
         np.testing.assert_allclose(output, output2, rtol=1e-5, atol=1e-5)
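These two test revisions serve the same invariant: the opt_level=2 and opt_level=0 graphs must hold identical parameter values (under different names, conv2d0_* vs conv2d1_*), otherwise the assertion would compare initializations rather than fusion behavior. A sketch of that invariant, reusing the names from the test above and assuming the create_workload signature seen there:

```python
# Sketch of the invariant patches 04-05 establish: same parameter values
# under different names. Reuses utils, sym1, sym2, dshape from the test above.
_, params1 = utils.create_workload(sym1, 1, dshape[1:], seed=0)
_, params2 = utils.create_workload(sym2, 1, dshape[1:], seed=0)
for p1, p2 in zip(params1.values(), params2.values()):
    assert np.allclose(p1.asnumpy(), p2.asnumpy())
```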
From dd31822ced712b844aa3b6741a93b4f25afe74e3 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sat, 4 Aug 2018 09:47:04 +0900
Subject: [PATCH 06/11] update test

---
 nnvm/tests/python/compiler/test_op_fusion.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py
index 35af41abe1cd..0ccb124e4214 100644
--- a/nnvm/tests/python/compiler/test_op_fusion.py
+++ b/nnvm/tests/python/compiler/test_op_fusion.py
@@ -85,7 +85,7 @@ def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
     module.set_input("data", data)
     module.run()
     out = module.get_output(0, tvm.nd.empty(out_shape))
-    return out.asnumpy(), graph
+    return out.asnumpy()
 
 
 def test_fuse_conv2d_elu():
@@ -112,12 +112,9 @@ def get_sym(out_channel):
         sym2 = get_sym(out_channel)
         _, params1 = utils.create_workload(sym1, 1, dshape[1:], seed=0)
         _, params2 = utils.create_workload(sym2, 1, dshape[1:], seed=0)
-        print(params1.keys())
-        print(params2.keys())
-        print("Running on target", target)
-        output, g1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
-        output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
-        np.testing.assert_allclose(output, output2, rtol=1e-5, atol=1e-5)
+        output1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
+        output2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
+        np.testing.assert_allclose(output1, output2, rtol=1e-5, atol=1e-5)
 
 if __name__ == "__main__":
     test_injective_reduce_injective()

From 5763ccf77dc0afd2b82397b58937d476f3f3cae6 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 5 Aug 2018 08:17:18 +0900
Subject: [PATCH 07/11] add explicit check for elemwise ops

---
 nnvm/src/compiler/graph_fuse.cc | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index 39142cf14fd0..6ba77353350e 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -171,10 +171,11 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
       CHECK_NE(group_vec[nid], -1);
       node_ids_per_group[group_vec[nid]].push_back(nid);
       if (inode.inputs.size() != 1) continue;
-      const auto& parent_nid = inode.inputs[0].node_id;
+      const uint32_t parent_nid = inode.inputs[0].node_id;
       // if parent node has more than one child, record each child's group id.
       if (ref_count[parent_nid] > 1) children_group_ids[parent_nid].push_back(group_vec[nid]);
     }
+
     std::vector<int> new_group_id(idx.num_nodes(), -1);
     for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
       if (new_group_id[group_vec[nid]] != -1) {
@@ -185,10 +186,23 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
         // propagate new group id from child
         group_vec[nid] = new_group_id[group_vec[nid]];
       }
       TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
       if (pt == kOpaque) continue;
       const auto& group_ids = children_group_ids[nid];
       if (group_ids.size() <= 1) continue;
-      const auto child_group_id = group_ids[0];
-      // fuse this node with children if all children belong to the same group
-      auto is_same_group_id = [child_group_id](uint32_t id) { return id == child_group_id; };
-      if (std::all_of(group_ids.begin(), group_ids.end(), is_same_group_id)) {
+      const uint32_t child_group_id = group_ids[0];
+      const auto& children_node_ids = node_ids_per_group[child_group_id];
+
+      auto is_same_group_id = [child_group_id](uint32_t id) {
+        return id == child_group_id;
+      };
+      auto is_fusible_pattern = [&idx](uint32_t child_nid) {
+        TOpPattern child_pt = op_pattern.get(idx[child_nid].source->op(), kOpaque);
+        return child_pt <= kBroadcast;
+      };
+      // fuse this node with children if
+      // all children belong to the same group and
+      // all nodes in the group are element-wise or broadcast ops.
+      const bool can_be_fused = std::all_of(group_ids.begin(), group_ids.end(), is_same_group_id) &&
+          std::all_of(children_node_ids.begin(), children_node_ids.end(), is_fusible_pattern);
+
+      if (can_be_fused) {
         new_group_id[group_vec[nid]] = child_group_id;
         group_vec[nid] = child_group_id;
         for (uint32_t nid2 : node_ids_per_group[child_group_id]) {
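Patch 07 makes the fusion condition explicit: the parent may join the child group only when every child branch landed in the same group and every node already in that group is element-wise or broadcast. Since `child_pt <= kBroadcast` relies on the enum ordering, a hypothetical Python rendering looks like this (constants assumed to mirror nnvm's TOpPattern declaration order):

```python
# Assumed to mirror nnvm's TOpPattern ordering: cheaper-to-fuse kinds first.
K_ELEM_WISE, K_BROADCAST, K_INJECTIVE, K_COMM_REDUCE, K_OUT_EWISE_FUSABLE, K_OPAQUE = range(6)

def can_fuse_parent(child_group_ids, child_patterns):
    # every recorded child branch must have landed in one group ...
    same_group = len(child_group_ids) > 1 and len(set(child_group_ids)) == 1
    # ... and everything already fused into that group must be cheap to fuse
    all_cheap = all(p <= K_BROADCAST for p in child_patterns)
    return same_group and all_cheap
```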
From 874feda12ae6c5f9056fb94ca28437480e49f6c0 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 5 Aug 2018 09:11:31 +0900
Subject: [PATCH 08/11] add doc for algorithm

---
 nnvm/src/compiler/graph_fuse.cc | 45 +++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index 6ba77353350e..868601b2d96d 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -162,6 +162,51 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     }
   }
 
+  /*
+   The above algorithm will not fuse a node whose output is fed to more than one
+   child node. This is because in general, it does not make sense to fuse multiple
+   children branches with their parent, as in the following example.
+
+        conv2d
+        /  |  \
+       /   |   \
+     op    op   op
+      |    |    |
+      |    |    |
+
+   However, when all children branches meet at a certain node, there is a possibility
+   for further operator fusion. For example, all nodes in the following subgraph can
+   be fused into a single node, if the three 'in-between' nodes and the bottom node
+   are all element-wise operations.
+
+        conv2d
+        /  |  \
+       /   |   \
+     op    op   op
+      \    |    /
+       \   |   /
+      elemwise add
+           |
+
+   This pattern is not uncommon. For example, it arises when a conv2d op is followed
+   by an exponential linear unit. If bias add and batch normalization are also
+   present, they can be fused as well.
+
+   In fact, the above fusion algorithm already fuses the three in-between nodes and
+   the element-wise add node in the figure above. The following code fuses the conv2d
+   node with the already fused children nodes. The following patterns are supported:
+
+   * Any number of child nodes from the top node
+   * The path from the top node to the bottom node can contain any number of
+     element-wise ops.
+
+   The only restriction is that in-between nodes cannot have more than one child.
+
+   The overview of the algorithm below is as follows:
+
+   1. Check if all child nodes are fused into a single op by the existing fusion algorithm
+   2. Fuse the parent node to the child nodes, and update its group id to be the children's group id
+   3. If the parent node originally belongs to another group (for example, conv + batch norm),
+      propagate the new group id to a grandparent and upward
+  */
   if (opt_level >= 1) {
     std::vector<std::vector<uint32_t>> children_group_ids(idx.num_nodes());
     std::vector<std::vector<uint32_t>> node_ids_per_group(idx.num_nodes());
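For concreteness, the diamond described in the comment can be written out with nnvm symbols (a sketch; the choice of element-wise ops on the branches is illustrative):

```python
import nnvm.symbol as sym

data = sym.Variable("data")
conv = sym.conv2d(data=data, channels=16, kernel_size=(3, 3), padding=(1, 1))
# three element-wise branches off the same parent ...
branch1 = sym.exp(conv)
branch2 = sym.relu(conv)
branch3 = sym.negative(conv)
# ... meeting again at an element-wise add: the whole diamond, conv2d
# included, becomes a single fused kernel at opt_level >= 1.
out = sym.elemwise_add(sym.elemwise_add(branch1, branch2), branch3)
```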
From 9af8646011468c707b988d16f6fd8d665812bef2 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 5 Aug 2018 11:56:07 +0900
Subject: [PATCH 09/11] fix for old compilers

---
 nnvm/src/compiler/graph_fuse.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index 868601b2d96d..247224aa9bfe 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -208,8 +208,8 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
       propagate the new group id to a grandparent and upward
   */
   if (opt_level >= 1) {
-    std::vector<std::vector<uint32_t>> children_group_ids(idx.num_nodes());
-    std::vector<std::vector<uint32_t>> node_ids_per_group(idx.num_nodes());
+    std::vector<std::vector<uint32_t> > children_group_ids(idx.num_nodes());
+    std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
     for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
       const auto& inode = idx[nid];
       if (inode.source->is_variable()) continue;

From 3a3da5e0cffdb83ac268ec7a6f1251378a22e032 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 5 Aug 2018 14:24:07 +0900
Subject: [PATCH 10/11] check number of nodes after fusion

---
 nnvm/tests/python/compiler/test_op_fusion.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py
index 0ccb124e4214..8d05ae02c579 100644
--- a/nnvm/tests/python/compiler/test_op_fusion.py
+++ b/nnvm/tests/python/compiler/test_op_fusion.py
@@ -85,7 +85,7 @@ def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
     module.set_input("data", data)
     module.run()
     out = module.get_output(0, tvm.nd.empty(out_shape))
-    return out.asnumpy()
+    return out.asnumpy(), graph
 
 
 def test_fuse_conv2d_elu():
@@ -112,9 +112,11 @@ def get_sym(out_channel):
         sym2 = get_sym(out_channel)
         _, params1 = utils.create_workload(sym1, 1, dshape[1:], seed=0)
         _, params2 = utils.create_workload(sym2, 1, dshape[1:], seed=0)
-        output1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
-        output2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
+        output1, g1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
+        output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
         np.testing.assert_allclose(output1, output2, rtol=1e-5, atol=1e-5)
+        # data, conv weight, bias, batch norm gamma, batch norm beta, conv op
+        assert g1.index.num_nodes == 6
 
 if __name__ == "__main__":
     test_injective_reduce_injective()
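The `num_nodes == 6` assertion is a compact proof that fusion fired: once batch norm is decomposed and its bound statistics are folded away during compilation, only the five parameter/input variables and the single fused convolution should survive. The same inspection can be done outside the test, as in this sketch reusing `sym1`, `dshape`, and `params1` from above:

```python
# Sketch: count surviving graph nodes after compilation at opt_level=2.
with nnvm.compiler.build_config(opt_level=2):
    graph, lib, params = nnvm.compiler.build(
        sym1, "llvm", shape={"data": dshape}, params=params1)
# expect 6: data, conv weight, conv bias, bn gamma, bn beta, fused conv op
print(graph.index.num_nodes)
```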
From e0dff617a16e421120cf0caf410658bf2b847ae7 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 7 Aug 2018 08:40:43 +0900
Subject: [PATCH 11/11] update traverse inline logic for arm cpu

---
 topi/python/topi/arm_cpu/conv2d.py |  5 +----
 topi/python/topi/util.py           | 26 +++++++++++++++++---------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index e28c08cb93ee..f5dbec8e552b 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -39,11 +39,10 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
 def schedule_conv2d_nchw_arm_cpu(cfg, outs):
     """TOPI schedule callback"""
     s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
 
     def _callback(op):
         # schedule conv2d
-        if 'spatial_conv_output' in op.tag and op not in scheduled_ops:
+        if 'spatial_conv_output' in op.tag:
             output = op.output(0)
             conv = op.input_tensors[0]
 
@@ -65,8 +64,6 @@ def _callback(op):
             output = op.output(0)
             _schedule_winograd(cfg, s, output, outs[0])
 
-        scheduled_ops.append(op)
-
     traverse_inline(s, outs[0].op, _callback)
     return s
 
diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py
index b5d5dd2b99ad..71e123e83475 100644
--- a/topi/python/topi/util.py
+++ b/topi/python/topi/util.py
@@ -5,25 +5,33 @@
 from . import tag
 
 
-def traverse_inline(s, op, callback):
+def traverse_inline(s, final_op, callback):
     """Traverse computation graph and do auto inline
 
     Parameters
     ----------
     s: schedule
         The schedule
-    op: Operation
+    final_op: Operation
         The final output operator.
     callback: callable
         The callback function on each op
     """
-    if tag.is_injective(op.tag):
-        if op not in s.outputs:
-            s[op].compute_inline()
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors:
-                traverse_inline(s, tensor.op, callback)
-    callback(op)
+    visited = set()
+
+    def _traverse(op):
+        if op in visited:
+            return
+        visited.add(op)
+        if tag.is_injective(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors:
+                    _traverse(tensor.op)
+        callback(op)
+
+    _traverse(final_op)
 
 
 def prod(x):
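With traversal memoized inside `traverse_inline`, per-schedule `scheduled_ops` bookkeeping becomes unnecessary: the callback fires exactly once per op even when a fused diamond reaches the same producer along several paths. Typical usage then looks like this sketch (the tag string and schedule body are hypothetical):

```python
import tvm
from topi.util import traverse_inline

def schedule_my_conv(outs):
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
        if 'my_conv_output' in op.tag:  # hypothetical tag
            conv = op.output(0)
            # op-specific scheduling would go here; _callback is guaranteed
            # to see each op once, so no duplicate-scheduling guard is needed

    traverse_inline(s, outs[0].op, _callback)
    return s
```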