Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions apps/graph_executor/src/graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -387,5 +387,115 @@ NNVM_REGISTER_OP(tvm_op)
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_outputs;
});


inline bool IsPair(LayoutInfo in, LayoutInfo out) {
if (in.src == out.dst && in.dst == out.src) return true;
return false;
}

inline LayoutInfo GetLayout(const nnvm::OpMap<FTVMLayoutInfo>& layouts,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have a function called CombineLayout? vector - > vector

const nnvm::NodePtr& n, int idx) {
return layouts[n->op()](n->attrs)[idx];
}

nnvm::NodePtr CreateLayoutTransformNode(std::string src, std::string dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform");
static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create();
n->attrs.op = trans_op;
n->attrs.name = src + "_to_" + dst + std::to_string(count++);
n->attrs.dict["src"] = src;
n->attrs.dict["dst"] = dst;
return n;
}


/*!
* \brief A simple layout transform pass that will
* insert layout transform nodes automatically.
*/
nnvm::Graph LayoutTransform(nnvm::Graph src) {
static auto& ilayouts =
nnvm::Op::GetAttr<FTVMInputsLayoutInfo>("FTVMInputsLayoutInfo");
static auto& olayouts =
nnvm::Op::GetAttr<FTVMOutputsLayoutInfo>("FTVMOutputsLayoutInfo");

std::unordered_map<nnvm::Node*, nnvm::NodePtr> mirror_map;
std::unordered_map<nnvm::Node*, std::vector<nnvm::NodePtr> > transformed;

DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *n;
if (new_node->is_variable()) {
mirror_map[n.get()] = new_node;
return;
}

if (olayouts.count(n->op())) {
std::vector<nnvm::NodePtr> tnodes(n->num_outputs(), nullptr);
for (uint32_t i = 0; i < n->num_outputs(); ++i) {
LayoutInfo layout = GetLayout(olayouts, n, i);
tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst);
tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0});
}
transformed.emplace(n.get(), std::move(tnodes));
}

for (size_t idx = 0; idx < n->inputs.size(); ++idx) {
const nnvm::NodeEntry& e = n->inputs[idx];
const nnvm::NodePtr& in = e.node;
new_node->inputs[idx] =
nnvm::NodeEntry{mirror_map.at(in.get()), e.index, e.version};

bool otrans = olayouts.count(in->op());
bool itrans = ilayouts.count(n->op());
if (otrans && itrans) {
LayoutInfo olayout = GetLayout(olayouts, in, e.index);
LayoutInfo ilayout = GetLayout(ilayouts, n, idx);
if (IsPair(olayout, ilayout)) {
continue;
}
}

if (otrans) {
const auto& tnodes = transformed.at(in.get());
new_node->inputs[idx] =
nnvm::NodeEntry{tnodes[e.index], 0, 0};
}

if (itrans) {
LayoutInfo layout = GetLayout(ilayouts, n, idx);
nnvm::NodePtr tnode =
CreateLayoutTransformNode(layout.src, layout.dst);
tnode->inputs.emplace_back(new_node->inputs[idx]);
new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0};
}
}
mirror_map[n.get()] = std::move(new_node);
});

std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : src.outputs) {
if (olayouts.count(e.node->op())) {
const auto& tnodes = transformed.at(e.node.get());
outputs.emplace_back(nnvm::NodeEntry{tnodes[e.index], 0, 0});
} else {
outputs.emplace_back(
nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version});
}
}

nnvm::Graph ret;
ret.outputs = std::move(outputs);
return ret;
}

NNVM_REGISTER_PASS(LayoutTransform)
.set_body(LayoutTransform);

NNVM_REGISTER_OP(layout_transform)
.set_num_inputs(1)
.set_num_outputs(1);
} // namespace contrib
} // namespace tvm
30 changes: 30 additions & 0 deletions apps/graph_executor/src/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,36 @@ using FTVMSchedule = std::function<
const Array<Tensor>& outs,
const std::string& target)>;

/*!
* \brief Layout transform information,
* from source layout to destination layout.
*/
struct LayoutInfo {
using Layout = std::string;
Layout src;
Layout dst;
};

/*!
* \brief Layout info of the node.
* \param attrs The attribute of the node.
* \return layouts A vector of inputs/outputs layout info.
*/
using FTVMLayoutInfo = std::function<
std::vector<LayoutInfo>(const NodeAttrs& attrs)>;
/*!
* \brief Inputs layout info of the node.
* \param attrs The attribute of the node.
* \return layouts A vector of inputs layout info.
*/
using FTVMInputsLayoutInfo = FTVMLayoutInfo;
/*!
* \brief Outputs layout info of the node.
* \param attrs The attribute of the node.
* \return layouts A vector of outputs layout info.
*/
using FTVMOutputsLayoutInfo = FTVMLayoutInfo;

// The storage result of op
enum OpPatternKind : int {
// Elementwise operation
Expand Down