Skip to content
This repository was archived by the owner on Feb 1, 2020. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace nnvm {

// Forward declare node.
class Node;
class Symbol;

/*!
* \brief we always used NodePtr for a reference pointer
Expand Down Expand Up @@ -90,6 +91,14 @@ struct NodeAttrs {
* The object can be used to quickly access attributes.
*/
any parsed;
/*!
* \brief Some operators take graphs as input. These operators include
* control flow operators and high-order functions.
* These graphs don't change when the operators are invoked for different
* mini-batches. In this sense, the subgraphs are kind of similar to
* the parameters and show be kept as node attributes.
*/
std::vector<std::shared_ptr<Symbol> > subgraphs;
};

/*!
Expand Down
12 changes: 12 additions & 0 deletions include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@ using FCorrectLayout = std::function<bool(
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts)>;

/*!
* \brief Get a list of inputs that represent graphs instead of data.
* Normally, input symbols are considered as data to the operator. However,
* control flow operators and high-order functions need to interpret symbols
* as graphs.
* \param attrs The attributes of this node.
* \return a list of input index that are interpreted as symbols by the operator.
*
* \note Register under "FInputGraph".
*/
using FInputGraph = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
81 changes: 62 additions & 19 deletions src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,43 +267,86 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph");

// The arguments that contain graphs.
Node* n = outputs[0].node.get();
FInputGraph fng = fgraph.get(n->op(), nullptr);
std::vector<uint32_t> garg_idx;
if (fng != nullptr)
garg_idx = fng(n->attrs);

// The names of the arguments that contain graphs.
FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (name_fn == nullptr) ? std::vector<std::string>{"data"} : name_fn(n->attrs);
std::vector<std::string> garg_names(garg_idx.size());
for (size_t i = 0; i < garg_idx.size(); i++) {
size_t idx = garg_idx[i];
if (idx < arg_names.size())
garg_names[i] = arg_names[idx];
}

// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U)
// If the argument isn't a graph, it should have only one output.
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
CHECK_EQ(kv.second->outputs.size(), 1U)
if (garg_names.empty()
|| std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
if (!name.empty()) outputs[0].node->attrs.name = name;

// Atomic functor composition.
if (IsAtomic(outputs)) {
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();
std::vector<const Symbol *> arg_vec(args.begin(), args.end());
std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
// If one of the input arguments is a graph, we need to remove it from the
// list.
if (fng != nullptr) {
std::vector<uint32_t> idxes = fng(n->attrs);
for (auto idx : idxes) {
const Symbol *sym;
if (idx < arg_vec.size()) {
sym = arg_vec[idx];
arg_vec.erase(arg_vec.begin() + idx);
} else {
auto it = kwarg_map.find(arg_names[idx]);
CHECK(it != kwarg_map.end());
sym = it->second;
kwarg_map.erase(it);
}

if (n_req != kVarg)
n_req--;
arg_names.erase(arg_names.begin() + idx);
n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
}
}

if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(args.size(), n_req)
CHECK_LE(arg_vec.size(), n_req)
<< "Incorrect number of arguments, requires " << n_req
<< ", provided " << args.size();
for (size_t i = 0; i < args.size(); ++i) {
n->inputs[i] = args[i]->outputs[0];
<< ", provided " << arg_vec.size();
for (size_t i = 0; i < arg_vec.size(); ++i) {
n->inputs[i] = arg_vec[i]->outputs[0];
}
// switch to keyword argument matching
if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_vec.size() != n_req) {
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
}
size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) {
auto it = kwargs.find(arg_names[i]);
if (it != kwargs.end() && it->first == arg_names[i]) {
for (size_t i = arg_vec.size(); i < n_req; ++i) {
auto it = kwarg_map.find(arg_names[i]);
if (it != kwarg_map.end() && it->first == arg_names[i]) {
n->inputs[i] = it->second->outputs[0];
++nmatched;
} else {
Expand All @@ -314,18 +357,18 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
}

if (nmatched != kwargs.size()) {
if (nmatched != kwarg_map.size()) {
n->inputs.clear();
std::vector<std::string> keys = GetKeys(kwargs);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + args.size(),
std::vector<std::string> keys = GetKeys(kwarg_map);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_vec.size(),
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, view);
}
}
} else {
CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(args.size());
for (const Symbol* s : args) {
CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(arg_vec.size());
for (const Symbol* s : arg_vec) {
n->inputs.push_back(s->outputs[0]);
}
}
Expand Down