diff --git a/apps/graph_executor/src/graph_pass.cc b/apps/graph_executor/src/graph_pass.cc index 5df34c6594cf..4c5f9541317c 100644 --- a/apps/graph_executor/src/graph_pass.cc +++ b/apps/graph_executor/src/graph_pass.cc @@ -387,5 +387,115 @@ NNVM_REGISTER_OP(tvm_op) const TVMOpParam& param = nnvm::get(attrs.parsed); return param.num_outputs; }); + + +inline bool IsPair(LayoutInfo in, LayoutInfo out) { + if (in.src == out.dst && in.dst == out.src) return true; + return false; +} + +inline LayoutInfo GetLayout(const nnvm::OpMap& layouts, + const nnvm::NodePtr& n, int idx) { + return layouts[n->op()](n->attrs)[idx]; +} + +nnvm::NodePtr CreateLayoutTransformNode(std::string src, std::string dst) { + static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform"); + static int count = 0; + nnvm::NodePtr n = nnvm::Node::Create(); + n->attrs.op = trans_op; + n->attrs.name = src + "_to_" + dst + std::to_string(count++); + n->attrs.dict["src"] = src; + n->attrs.dict["dst"] = dst; + return n; +} + + +/*! + * \brief A simple layout transform pass that will + * insert layout transform nodes automatically. + */ +nnvm::Graph LayoutTransform(nnvm::Graph src) { + static auto& ilayouts = + nnvm::Op::GetAttr("FTVMInputsLayoutInfo"); + static auto& olayouts = + nnvm::Op::GetAttr("FTVMOutputsLayoutInfo"); + + std::unordered_map mirror_map; + std::unordered_map > transformed; + + DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) { + nnvm::NodePtr new_node = nnvm::Node::Create(); + *new_node = *n; + if (new_node->is_variable()) { + mirror_map[n.get()] = new_node; + return; + } + + if (olayouts.count(n->op())) { + std::vector tnodes(n->num_outputs(), nullptr); + for (uint32_t i = 0; i < n->num_outputs(); ++i) { + LayoutInfo layout = GetLayout(olayouts, n, i); + tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst); + tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0}); + } + transformed.emplace(n.get(), std::move(tnodes)); + } + + for (size_t idx = 0; idx < n->inputs.size(); ++idx) { + const nnvm::NodeEntry& e = n->inputs[idx]; + const nnvm::NodePtr& in = e.node; + new_node->inputs[idx] = + nnvm::NodeEntry{mirror_map.at(in.get()), e.index, e.version}; + + bool otrans = olayouts.count(in->op()); + bool itrans = ilayouts.count(n->op()); + if (otrans && itrans) { + LayoutInfo olayout = GetLayout(olayouts, in, e.index); + LayoutInfo ilayout = GetLayout(ilayouts, n, idx); + if (IsPair(olayout, ilayout)) { + continue; + } + } + + if (otrans) { + const auto& tnodes = transformed.at(in.get()); + new_node->inputs[idx] = + nnvm::NodeEntry{tnodes[e.index], 0, 0}; + } + + if (itrans) { + LayoutInfo layout = GetLayout(ilayouts, n, idx); + nnvm::NodePtr tnode = + CreateLayoutTransformNode(layout.src, layout.dst); + tnode->inputs.emplace_back(new_node->inputs[idx]); + new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0}; + } + } + mirror_map[n.get()] = std::move(new_node); + }); + + std::vector outputs; + for (const auto& e : src.outputs) { + if (olayouts.count(e.node->op())) { + const auto& tnodes = transformed.at(e.node.get()); + outputs.emplace_back(nnvm::NodeEntry{tnodes[e.index], 0, 0}); + } else { + outputs.emplace_back( + nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version}); + } + } + + nnvm::Graph ret; + ret.outputs = std::move(outputs); + return ret; +} + +NNVM_REGISTER_PASS(LayoutTransform) +.set_body(LayoutTransform); + +NNVM_REGISTER_OP(layout_transform) +.set_num_inputs(1) +.set_num_outputs(1); } // namespace contrib } // namespace tvm diff --git a/apps/graph_executor/src/op_attr_types.h b/apps/graph_executor/src/op_attr_types.h index c7b4a55e5eba..1012b268c98c 100644 --- a/apps/graph_executor/src/op_attr_types.h +++ b/apps/graph_executor/src/op_attr_types.h @@ -52,6 +52,36 @@ using FTVMSchedule = std::function< const Array& outs, const std::string& target)>; +/*! + * \brief Layout transform information, + * from source layout to destination layout. + */ +struct LayoutInfo { + using Layout = std::string; + Layout src; + Layout dst; +}; + +/*! + * \brief Layout info of the node. + * \param attrs The attribute of the node. + * \return layouts A vector of inputs/outputs layout info. + */ +using FTVMLayoutInfo = std::function< + std::vector(const NodeAttrs& attrs)>; +/*! + * \brief Inputs layout info of the node. + * \param attrs The attribute of the node. + * \return layouts A vector of inputs layout info. + */ +using FTVMInputsLayoutInfo = FTVMLayoutInfo; +/*! + * \brief Outputs layout info of the node. + * \param attrs The attribute of the node. + * \return layouts A vector of outputs layout info. + */ +using FTVMOutputsLayoutInfo = FTVMLayoutInfo; + // The storage result of op enum OpPatternKind : int { // Elementwise operation