From c87ffd9062bc07063e035e182933330278e499c9 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 8 Jul 2017 16:23:19 -0700 Subject: [PATCH 1/3] [PASS] Layout transform pass --- apps/graph_executor/src/graph_pass.cc | 109 ++++++++++++++++++++++++ apps/graph_executor/src/op_attr_types.h | 12 +++ 2 files changed, 121 insertions(+) diff --git a/apps/graph_executor/src/graph_pass.cc b/apps/graph_executor/src/graph_pass.cc index 5df34c6594cf..5c3fe2b9ed6b 100644 --- a/apps/graph_executor/src/graph_pass.cc +++ b/apps/graph_executor/src/graph_pass.cc @@ -387,5 +387,114 @@ NNVM_REGISTER_OP(tvm_op) const TVMOpParam& param = nnvm::get(attrs.parsed); return param.num_outputs; }); + + +inline bool IsPair(LayoutInfo in, LayoutInfo out) { + if (in.src == out.dst && in.dst == out.src) return true; + return false; +} + +inline LayoutInfo GetLayout(const nnvm::OpMap& layouts, + const nnvm::NodePtr& n, int idx) { + return layouts[n->op()](n->attrs)[idx]; +} + +nnvm::NodePtr CreateLayoutTransformNode(std::string src, std::string dst) { + static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform"); + static int count = 0; + nnvm::NodePtr n = nnvm::Node::Create(); + n->attrs.op = trans_op; + n->attrs.name = src + "_to_" + dst + std::to_string(count++); + n->attrs.dict["src"] = src; + n->attrs.dict["dst"] = dst; + return n; +} + + +nnvm::Graph LayoutTransform(nnvm::Graph src) { + static auto& ilayouts = + nnvm::Op::GetAttr("FTVMInputsLayoutInfo"); + static auto& olayouts = + nnvm::Op::GetAttr("FTVMOutputsLayoutInfo"); + + nnvm::NodeEntryMap transformed; + std::unordered_map mirror_map; + + DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) { + nnvm::NodePtr new_node = nnvm::Node::Create(); + *new_node = *n; + if (new_node->is_variable()) { + mirror_map[n.get()] = new_node; + return; + } + + for (size_t idx = 0; idx < n->inputs.size(); ++idx) { + const nnvm::NodeEntry& e = n->inputs[idx]; + const nnvm::NodePtr& in = e.node; + new_node->inputs[idx] = e; + + bool otrans = olayouts.count(in->op()); + bool itrans = ilayouts.count(n->op()); + if (otrans && itrans) { + LayoutInfo olayout = GetLayout(olayouts, in, e.index); + LayoutInfo ilayout = GetLayout(ilayouts, n, idx); + if (IsPair(olayout, ilayout)) { + break; + } + } + + if (otrans) { + LayoutInfo layout = GetLayout(olayouts, in, e.index); + if (!transformed.count(e)) { + nnvm::NodePtr tnode = + CreateLayoutTransformNode(layout.src, layout.dst); + tnode->inputs.emplace_back(e); + transformed.emplace(e, nnvm::NodeEntry{tnode, 0, 0}); + } + new_node->inputs[idx] = transformed.at(e); + } + + if (itrans) { + LayoutInfo layout = GetLayout(ilayouts, n, idx); + nnvm::NodePtr tnode = + CreateLayoutTransformNode(layout.src, layout.dst); + tnode->inputs.emplace_back(new_node->inputs[idx]); + new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0}; + } + } + mirror_map[n.get()] = std::move(new_node); + }); + + + std::vector outputs; + for (const auto& e : src.outputs) { + nnvm::NodePtr mirror_node = mirror_map.at(e.node.get()); + nnvm::NodeEntry mirror_entry{mirror_node, e.index, e.version}; + + if (olayouts.count(e.node->op())) { + LayoutInfo layout = GetLayout(olayouts, e.node, e.index); + nnvm::NodePtr tnode = + CreateLayoutTransformNode(layout.src, layout.dst); + tnode->inputs.emplace_back(mirror_entry); + + outputs.emplace_back( + nnvm::NodeEntry{tnode, 0, 0}); + } else { + outputs.emplace_back( + nnvm::NodeEntry{mirror_node, e.index, e.version}); + } + } + + nnvm::Graph ret; + ret.outputs = std::move(outputs); + return ret; +} + +NNVM_REGISTER_PASS(LayoutTransform) +.set_body(LayoutTransform); + +NNVM_REGISTER_OP(layout_transform) +.set_num_inputs(1) +.set_num_outputs(1); } // namespace contrib } // namespace tvm diff --git a/apps/graph_executor/src/op_attr_types.h b/apps/graph_executor/src/op_attr_types.h index c7b4a55e5eba..6f8f07ef61ab 100644 --- a/apps/graph_executor/src/op_attr_types.h +++ b/apps/graph_executor/src/op_attr_types.h @@ -52,6 +52,18 @@ using FTVMSchedule = std::function< const Array& outs, const std::string& target)>; +// TODO better structure of layout +struct LayoutInfo { + using Layout = std::string; + Layout src; + Layout dst; +}; + +using FTVMLayoutInfo = std::function< + std::vector(const NodeAttrs& attrs)>; +using FTVMInputsLayoutInfo = FTVMLayoutInfo; +using FTVMOutputsLayoutInfo = FTVMLayoutInfo; + // The storage result of op enum OpPatternKind : int { // Elementwise operation From fe1d134bbffc431b0b9a7f1989ec47da6ece2550 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sun, 9 Jul 2017 14:32:41 -0700 Subject: [PATCH 2/3] Fix according to comment --- apps/graph_executor/src/graph_pass.cc | 41 ++++++++++++------------- apps/graph_executor/src/op_attr_types.h | 20 +++++++++++- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/apps/graph_executor/src/graph_pass.cc b/apps/graph_executor/src/graph_pass.cc index 5c3fe2b9ed6b..1f6fbc20ef4a 100644 --- a/apps/graph_executor/src/graph_pass.cc +++ b/apps/graph_executor/src/graph_pass.cc @@ -411,6 +411,10 @@ nnvm::NodePtr CreateLayoutTransformNode(std::string src, std::string dst) { } +/*! + * \brief A simple layout transform pass that will + * insert layout transform nodes automatically. + */ nnvm::Graph LayoutTransform(nnvm::Graph src) { static auto& ilayouts = nnvm::Op::GetAttr("FTVMInputsLayoutInfo"); @@ -428,10 +432,22 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { return; } + if (olayouts.count(n->op())) { + for (uint32_t i = 0; i < n->num_outputs(); ++i) { + LayoutInfo layout = GetLayout(olayouts, n, i); + nnvm::NodePtr tnode = + CreateLayoutTransformNode(layout.src, layout.dst); + tnode->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0}); + transformed.emplace( + nnvm::NodeEntry{n, i, 0}, nnvm::NodeEntry{tnode, 0, 0}); + } + } + for (size_t idx = 0; idx < n->inputs.size(); ++idx) { const nnvm::NodeEntry& e = n->inputs[idx]; const nnvm::NodePtr& in = e.node; - new_node->inputs[idx] = e; + new_node->inputs[idx] = + nnvm::NodeEntry{mirror_map.at(in.get()), e.index, e.version}; bool otrans = olayouts.count(in->op()); bool itrans = ilayouts.count(n->op()); @@ -439,18 +455,11 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { LayoutInfo olayout = GetLayout(olayouts, in, e.index); LayoutInfo ilayout = GetLayout(ilayouts, n, idx); if (IsPair(olayout, ilayout)) { - break; + continue; } } if (otrans) { - LayoutInfo layout = GetLayout(olayouts, in, e.index); - if (!transformed.count(e)) { - nnvm::NodePtr tnode = - CreateLayoutTransformNode(layout.src, layout.dst); - tnode->inputs.emplace_back(e); - transformed.emplace(e, nnvm::NodeEntry{tnode, 0, 0}); - } new_node->inputs[idx] = transformed.at(e); } @@ -465,23 +474,13 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { mirror_map[n.get()] = std::move(new_node); }); - std::vector outputs; for (const auto& e : src.outputs) { - nnvm::NodePtr mirror_node = mirror_map.at(e.node.get()); - nnvm::NodeEntry mirror_entry{mirror_node, e.index, e.version}; - if (olayouts.count(e.node->op())) { - LayoutInfo layout = GetLayout(olayouts, e.node, e.index); - nnvm::NodePtr tnode = - CreateLayoutTransformNode(layout.src, layout.dst); - tnode->inputs.emplace_back(mirror_entry); - - outputs.emplace_back( - nnvm::NodeEntry{tnode, 0, 0}); + outputs.emplace_back(transformed.at(e)); } else { outputs.emplace_back( - nnvm::NodeEntry{mirror_node, e.index, e.version}); + nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version}); } } diff --git a/apps/graph_executor/src/op_attr_types.h b/apps/graph_executor/src/op_attr_types.h index 6f8f07ef61ab..1012b268c98c 100644 --- a/apps/graph_executor/src/op_attr_types.h +++ b/apps/graph_executor/src/op_attr_types.h @@ -52,16 +52,34 @@ using FTVMSchedule = std::function< const Array& outs, const std::string& target)>; -// TODO better structure of layout +/*! + * \brief Layout transform information, + * from source layout to destination layout. + */ struct LayoutInfo { using Layout = std::string; Layout src; Layout dst; }; +/*! + * \brief Layout info of the node. + * \param attrs The attribute of the node. + * \return layouts A vector of inputs/outputs layout info. + */ using FTVMLayoutInfo = std::function< std::vector(const NodeAttrs& attrs)>; +/*! + * \brief Inputs layout info of the node. + * \param attrs The attribute of the node. + * \return layouts A vector of inputs layout info. + */ using FTVMInputsLayoutInfo = FTVMLayoutInfo; +/*! + * \brief Outputs layout info of the node. + * \param attrs The attribute of the node. + * \return layouts A vector of outputs layout info. + */ using FTVMOutputsLayoutInfo = FTVMLayoutInfo; // The storage result of op From bdddba66ef2e4628d7dc7a7777d76b2ea2fd3682 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sun, 9 Jul 2017 16:04:11 -0700 Subject: [PATCH 3/3] Fix --- apps/graph_executor/src/graph_pass.cc | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/apps/graph_executor/src/graph_pass.cc b/apps/graph_executor/src/graph_pass.cc index 1f6fbc20ef4a..4c5f9541317c 100644 --- a/apps/graph_executor/src/graph_pass.cc +++ b/apps/graph_executor/src/graph_pass.cc @@ -421,8 +421,8 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { static auto& olayouts = nnvm::Op::GetAttr("FTVMOutputsLayoutInfo"); - nnvm::NodeEntryMap transformed; std::unordered_map mirror_map; + std::unordered_map > transformed; DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) { nnvm::NodePtr new_node = nnvm::Node::Create(); @@ -433,14 +433,13 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { } if (olayouts.count(n->op())) { + std::vector tnodes(n->num_outputs(), nullptr); for (uint32_t i = 0; i < n->num_outputs(); ++i) { LayoutInfo layout = GetLayout(olayouts, n, i); - nnvm::NodePtr tnode = - CreateLayoutTransformNode(layout.src, layout.dst); - tnode->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0}); - transformed.emplace( - nnvm::NodeEntry{n, i, 0}, nnvm::NodeEntry{tnode, 0, 0}); + tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst); + tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0}); } + transformed.emplace(n.get(), std::move(tnodes)); } for (size_t idx = 0; idx < n->inputs.size(); ++idx) { @@ -460,7 +459,9 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { } if (otrans) { - new_node->inputs[idx] = transformed.at(e); + const auto& tnodes = transformed.at(in.get()); + new_node->inputs[idx] = + nnvm::NodeEntry{tnodes[e.index], 0, 0}; } if (itrans) { @@ -477,7 +478,8 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { std::vector outputs; for (const auto& e : src.outputs) { if (olayouts.count(e.node->op())) { - outputs.emplace_back(transformed.at(e)); + const auto& tnodes = transformed.at(e.node.get()); + outputs.emplace_back(nnvm::NodeEntry{tnodes[e.index], 0, 0}); } else { outputs.emplace_back( nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version});