diff --git a/include/nnvm/node.h b/include/nnvm/node.h index 2c7d0ef30..bcfdb95c4 100644 --- a/include/nnvm/node.h +++ b/include/nnvm/node.h @@ -18,6 +18,7 @@ namespace nnvm { // Forward declare node. class Node; +class Symbol; /*! * \brief we always used NodePtr for a reference pointer @@ -90,6 +91,14 @@ struct NodeAttrs { * The object can be used to quickly access attributes. */ any parsed; + /*! + * \brief Some operators take graphs as input. These operators include + * control flow operators and high-order functions. + * These graphs don't change when the operators are invoked for different + * mini-batches. In this sense, the subgraphs are kind of similar to + * the parameters and should be kept as node attributes. + */ + std::vector > subgraphs; }; /*! diff --git a/include/nnvm/op_attr_types.h b/include/nnvm/op_attr_types.h index e58e9ceb3..b7f6be408 100644 --- a/include/nnvm/op_attr_types.h +++ b/include/nnvm/op_attr_types.h @@ -202,6 +202,18 @@ using FCorrectLayout = std::function *last_ilayouts, std::vector *olayouts)>; +/*! + * \brief Get a list of inputs that represent graphs instead of data. + * Normally, input symbols are considered as data to the operator. However, + * control flow operators and high-order functions need to interpret symbols + * as graphs. + * \param attrs The attributes of this node. + * \return a list of input indices that are interpreted as symbols by the operator. + * + * \note Register under "FInputGraph". 
+ */ +using FInputGraph = std::function(const NodeAttrs& attrs)>; + } // namespace nnvm #endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/src/core/symbolic.cc b/src/core/symbolic.cc index 2a2f5be50..927dd2b70 100644 --- a/src/core/symbolic.cc +++ b/src/core/symbolic.cc @@ -267,14 +267,36 @@ void Symbol::Compose(const array_view& args, const std::string& name) { static auto& flist_inputs = Op::GetAttr("FListInputNames"); static auto& fset_attrs = Op::GetAttr("FSetInputVarAttrOnCompose"); + static auto& fgraph = Op::GetAttr("FInputGraph"); + + // The arguments that contain graphs. + Node* n = outputs[0].node.get(); + FInputGraph fng = fgraph.get(n->op(), nullptr); + std::vector garg_idx; + if (fng != nullptr) + garg_idx = fng(n->attrs); + + // The names of the arguments that contain graphs. + FListInputNames name_fn = flist_inputs.get(n->op(), nullptr); + auto arg_names = (name_fn == nullptr) ? std::vector{"data"} : name_fn(n->attrs); + std::vector garg_names(garg_idx.size()); + for (size_t i = 0; i < garg_idx.size(); i++) { + size_t idx = garg_idx[i]; + if (idx < arg_names.size()) + garg_names[i] = arg_names[idx]; + } // parameter check. for (size_t i = 0; i < args.size(); ++i) { - CHECK_EQ(args[i]->outputs.size(), 1U) + // If the argument isn't a graph, it should have only one output. + if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end()) + CHECK_EQ(args[i]->outputs.size(), 1U) << "Argument " << i << " is a tuple, single value is required"; } for (const auto& kv : kwargs) { - CHECK_EQ(kv.second->outputs.size(), 1U) + if (garg_names.empty() + || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) + CHECK_EQ(kv.second->outputs.size(), 1U) << "Keyword Argument " << kv.first << " is a tuple, single value is required"; } // assign new name @@ -282,28 +304,49 @@ void Symbol::Compose(const array_view& args, // Atomic functor composition. 
if (IsAtomic(outputs)) { - Node* n = outputs[0].node.get(); uint32_t n_req = n->num_inputs(); + std::vector arg_vec(args.begin(), args.end()); + std::unordered_map kwarg_map(kwargs.begin(), kwargs.end()); + // If one of the input arguments is a graph, we need to remove it from the + // list. + if (fng != nullptr) { + std::vector idxes = fng(n->attrs); + for (auto idx : idxes) { + const Symbol *sym; + if (idx < arg_vec.size()) { + sym = arg_vec[idx]; + arg_vec.erase(arg_vec.begin() + idx); + } else { + auto it = kwarg_map.find(arg_names[idx]); + CHECK(it != kwarg_map.end()); + sym = it->second; + kwarg_map.erase(it); + } + + if (n_req != kVarg) + n_req--; + arg_names.erase(arg_names.begin() + idx); + n->attrs.subgraphs.push_back(std::make_shared(*sym)); + } + } if (n_req != kVarg) { n->inputs.resize(n_req); - CHECK_LE(args.size(), n_req) + CHECK_LE(arg_vec.size(), n_req) << "Incorrect number of arguments, requires " << n_req - << ", provided " << args.size(); - for (size_t i = 0; i < args.size(); ++i) { - n->inputs[i] = args[i]->outputs[0]; + << ", provided " << arg_vec.size(); + for (size_t i = 0; i < arg_vec.size(); ++i) { + n->inputs[i] = arg_vec[i]->outputs[0]; } // switch to keyword argument matching - if (args.size() != n_req) { - FListInputNames fn = flist_inputs.get(n->op(), nullptr); - auto arg_names = (fn == nullptr) ? 
std::vector{"data"} : fn(n->attrs); + if (arg_vec.size() != n_req) { if (arg_names.size() != n_req) { LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name; } size_t nmatched = 0; - for (size_t i = args.size(); i < n_req; ++i) { - auto it = kwargs.find(arg_names[i]); - if (it != kwargs.end() && it->first == arg_names[i]) { + for (size_t i = arg_vec.size(); i < n_req; ++i) { + auto it = kwarg_map.find(arg_names[i]); + if (it != kwarg_map.end() && it->first == arg_names[i]) { n->inputs[i] = it->second->outputs[0]; ++nmatched; } else { @@ -314,18 +357,18 @@ void Symbol::Compose(const array_view& args, } } - if (nmatched != kwargs.size()) { + if (nmatched != kwarg_map.size()) { n->inputs.clear(); - std::vector keys = GetKeys(kwargs); - array_view view(dmlc::BeginPtr(arg_names) + args.size(), + std::vector keys = GetKeys(kwarg_map); + array_view view(dmlc::BeginPtr(arg_names) + arg_vec.size(), dmlc::BeginPtr(arg_names) + arg_names.size()); KeywordArgumentMismatch("Symbol.Compose", keys, view); } } } else { - CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs"; - n->inputs.reserve(args.size()); - for (const Symbol* s : args) { + CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs"; + n->inputs.reserve(arg_vec.size()); + for (const Symbol* s : arg_vec) { n->inputs.push_back(s->outputs[0]); } }