From d0fb0d3712054c4f13f01c9fc98b9dc97e6dbec9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 22 Nov 2015 12:46:02 -0800 Subject: [PATCH] [SYMBOL] enable attributes in graph node --- dmlc-core | 2 +- include/mxnet/c_api.h | 31 ++++++++++++ include/mxnet/symbolic.h | 17 +++++++ python/mxnet/__init__.py | 4 ++ python/mxnet/attribute.py | 62 +++++++++++++++++++++++ python/mxnet/symbol.py | 52 +++++++++++++++++-- src/c_api/c_api.cc | 26 ++++++++++ src/symbol/static_graph.cc | 3 ++ src/symbol/static_graph.h | 26 +++++----- src/symbol/symbol.cc | 80 ++++++++++++++++++++++++------ tests/python/unittest/test_attr.py | 33 ++++++++++++ 11 files changed, 306 insertions(+), 30 deletions(-) create mode 100644 python/mxnet/attribute.py create mode 100644 tests/python/unittest/test_attr.py diff --git a/dmlc-core b/dmlc-core index 6750e79201e5..4b951c037838 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 6750e79201e568e2a46d531c3fab7f6e31abb562 +Subproject commit 4b951c0378386b7f4d9eae72be2ecd3b9c816afe diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 247d05db072e..4f8d140df65f 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -445,6 +445,37 @@ MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out); * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str); +/*! + * \brief Get string attribute from symbol + * \param symbol the source symbol + * \param key The key of the symbol. + * \param out The result attribute, can be NULL if the attribute do not exist. + * \param success Whether the result is contained in out. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolGetAttr(SymbolHandle symbol, + const char* key, + const char** out, + int *success); +/*! + * \brief Set string attribute from symbol. + * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. 
+ * + * Safe recommendation: use immutable graph + * - Only allow set attributes during creation of new symbol as optional parameter + * + * Mutable graph (be careful about the semantics): + * - Allow set attr at any point. + * - Mutating an attribute of some common node of two graphs can cause confusion for the user. + * + * \param symbol the source symbol + * \param key The key of the attribute. + * \param value The value to be saved. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolSetAttr(SymbolHandle symbol, + const char* key, + const char* value); /*! * \brief List arguments in the symbol. * \param symbol the symbol diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index c3f6d05dbb9a..c50be94a9695 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -86,6 +86,23 @@ class Symbol { */ void Compose(const std::unordered_map& kwargs, const std::string& name); + /*! + * \brief set additional attributes of the symbol, + * This only works for symbol with outputs from single operators. + * For grouped symbol, an error will be raised. + * \param key the key of the attribute + * \param value the value of the attribute. + */ + void SetAttr(const std::string &key, const std::string& value); + /*! + * \brief Get attributes from the symbol. + * This only works for symbol with outputs from single operators. + * For grouped symbol, an error will be raised. + * \param key Key of the attribute. + * \param out the output value of the attribute. + * \return true if the attribute exists, false if the attribute does not exist. + */ + bool GetAttr(const std::string& key, std::string* out); /*! 
* \brief Apply the symbol as a function, compose with arguments * \param args positional arguments for the symbol diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 69e209033e3e..054d241a044c 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -33,6 +33,10 @@ # use mx.kv as short for kvstore from . import kvstore as kv from . import kvstore_server +# Runtime compile module from .rtc import Rtc as rtc +# Attribute scope to add attributes to symbolic graphs +from .attribute import AttrScope + __version__ = base.__version__ diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py new file mode 100644 index 000000000000..1685f48b4749 --- /dev/null +++ b/python/mxnet/attribute.py @@ -0,0 +1,62 @@ +# coding: utf-8 +"""Attribute scoping support for symbolic API.""" +from __future__ import absolute_import + +from .base import string_types + +class AttrScope(object): + """Attribute manager for scoping. + + Users can also inherit from this object to change naming behavior. + + Parameters + ---------- + kwargs + The attributes to set for all symbol creations in the scope. + """ + current = None + + def __init__(self, **kwargs): + self._old_scope = None + for value in kwargs.values(): + if not isinstance(value, string_types): + raise ValueError("Attributes need to be string") + self._attr = kwargs + + def get(self, attr): + """ + Get the attribute dict given the attribute set by the symbol. + + Parameters + ---------- + attr : dict of string to string + The attribute passed in by user during symbol creation. + + Returns + ------- + attr : dict of string to string + Updated attributes to add other scope related attributes. 
+ """ + if self._attr: + ret = self._attr.copy() + if attr: + ret.update(attr) + return ret + else: + return attr + + def __enter__(self): + # pylint: disable=protected-access + self._old_scope = AttrScope.current + attr = AttrScope.current._attr.copy() + attr.update(self._attr) + self._attr = attr + AttrScope.current = self + return self + + def __exit__(self, ptype, value, trace): + assert self._old_scope + AttrScope.current = self._old_scope + +AttrScope.current = AttrScope() + diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index df9a27926d94..c49ab39f561c 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -11,6 +11,7 @@ from .base import NDArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call, ctypes2docstring from .name import NameManager +from .attribute import AttrScope from .context import Context from .ndarray import NDArray, zeros from .executor import Executor @@ -199,6 +200,42 @@ def __getitem__(self, index): self.handle, mx_uint(index), ctypes.byref(handle))) return Symbol(handle=handle) + def attr(self, key): + """Get attribute string from the symbol, this function only works for non-grouped symbol. + + Parameters + ---------- + key : str + The key to get attribute from. + + Returns + ------- + value : str + The attribute value of the key, returns None if attribute do not exist. + """ + ret = ctypes.c_char_p() + success = ctypes.c_int() + check_call(_LIB.MXSymbolGetAttr( + self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success))) + if success.value != 0: + return py_str(ret.value) + else: + return None + + def _set_attr(self, **kwargs): + """Set the attribute of the symbol. 
+ + Parameters + ---------- + **kwargs + The attributes to set + """ + for key, value in kwargs.items(): + if not isinstance(value, string_types): + raise ValueError("Set Attr only accepts string values") + check_call(_LIB.MXSymbolSetAttr( + self.handle, c_str(key), c_str(str(value)))) + def get_internals(self): """Get a new grouped symbol whose output contains all the internal outputs of this symbol. @@ -630,13 +667,15 @@ def grad(self, wrt): # pylint: enable= no-member -def Variable(name): +def Variable(name, attr=None): """Create a symbolic variable with specified name. Parameters ---------- name : str Name of the variable. + attr : dict of string -> string + Additional attributes to set on the variable. Returns ------- @@ -647,7 +686,11 @@ def Variable(name): raise TypeError('Expect a string for variable `name`') handle = SymbolHandle() check_call(_LIB.MXSymbolCreateVariable(c_str(name), ctypes.byref(handle))) - return Symbol(handle) + ret = Symbol(handle) + attr = AttrScope.current.get(attr) + if attr: + ret._set_attr(**attr) + return ret def Group(symbols): @@ -784,6 +827,7 @@ def creator(*args, **kwargs): param_vals = [] symbol_kwargs = {} name = kwargs.pop('name', None) + attr = kwargs.pop('attr', None) if key_var_num_args and key_var_num_args not in kwargs: param_keys.append(c_str(key_var_num_args)) @@ -813,8 +857,10 @@ def creator(*args, **kwargs): raise ValueError('This function support variable length of Symbol arguments.\n' + 'Please pass all the input Symbols via positional arguments' + ' instead of keyword arguments.') - s = Symbol(sym_handle) + attr = AttrScope.current.get(attr) + if attr: + s._set_attr(**attr) hint = func_name.lower() name = NameManager.current.get(name, hint) s._compose(*args, name=name, **symbol_kwargs) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index e1228f2d3ecc..099627c1a3a4 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -532,6 +532,32 @@ int MXSymbolPrint(SymbolHandle symbol, const char **out_str) { 
API_END(); } +int MXSymbolGetAttr(SymbolHandle symbol, + const char* key, + const char** out, + int* success) { + Symbol *s = static_cast(symbol); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + if (s->GetAttr(key, &(ret->ret_str))) { + *out = (ret->ret_str).c_str(); + *success = 1; + } else { + *out = nullptr; + *success = 0; + } + API_END(); +} + +int MXSymbolSetAttr(SymbolHandle symbol, + const char* key, + const char* value) { + Symbol *s = static_cast(symbol); + API_BEGIN(); + s->SetAttr(key, value); + API_END(); +} + int MXSymbolListArguments(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array) { diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index d575984dd1f9..07107915cd52 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -323,10 +323,12 @@ void StaticGraph::Node::Save(dmlc::JSONWriter *writer) const { writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("inputs", inputs); writer->WriteObjectKeyValue("backward_source_id", backward_source_id); + if (attr.size() != 0) writer->WriteObjectKeyValue("attr", attr); writer->EndObject(); } void StaticGraph::Node::Load(dmlc::JSONReader *reader) { + attr.clear(); dmlc::JSONObjectReadHelper helper; std::string op_type_str; std::map param; @@ -335,6 +337,7 @@ void StaticGraph::Node::Load(dmlc::JSONReader *reader) { helper.DeclareField("name", &name); helper.DeclareField("inputs", &inputs); helper.DeclareField("backward_source_id", &backward_source_id); + helper.DeclareOptionalField("attr", &attr); helper.ReadAllFields(reader); if (op_type_str != "null") { diff --git a/src/symbol/static_graph.h b/src/symbol/static_graph.h index 514a8f6d80a0..0d3257c0b06b 100644 --- a/src/symbol/static_graph.h +++ b/src/symbol/static_graph.h @@ -16,6 +16,7 @@ #include #include #include +#include namespace mxnet { /*! 
@@ -109,23 +110,24 @@ class StaticGraph { * When the node is a Backward node, the op field will be nullptr */ int32_t backward_source_id; + /*! \brief additional attributes about the node */ + std::map attr; /*! \brief default constructor */ Node() : backward_source_id(-1) {} - - friend void swap(Node& lhs, Node& rhs) { - std::swap(lhs.op, rhs.op); - std::swap(lhs.name, rhs.name); - std::swap(lhs.inputs, rhs.inputs); - std::swap(lhs.backward_source_id, rhs.backward_source_id); - } /*! \brief copy constructor in favor of serialization. */ - Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr), - name(another.name), - inputs(another.inputs), - backward_source_id(another.backward_source_id) {} + Node(const Node& another) + : op(another.op.get() ? another.op.get()->Copy() : nullptr), + name(another.name), + inputs(another.inputs), + backward_source_id(another.backward_source_id), + attr(another.attr) {} inline Node& operator=(Node another) { - swap(*this, another); + op = std::move(another.op); + name = std::move(another.name); + inputs = std::move(another.inputs); + backward_source_id = std::move(another.backward_source_id); + attr = std::move(another.attr); return *this; } /*! \return whether the node is forward op node */ diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 45255f095985..5027ee7a52ca 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -22,21 +22,39 @@ namespace mxnet { * - Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed. */ struct Symbol::Node { - /*! \brief source node of the current node */ - std::shared_ptr backward_source_node; /*! \brief Operator of this node */ std::unique_ptr op; /*! \brief name of the node */ std::string name; /*! \brief inputs to this node */ std::vector inputs; + /*! \brief source node of the current node */ + std::shared_ptr backward_source_node; + /*! 
+ * \brief additional attributes about the node, + * Use pointer to save space, as attr can be accessed in a slow way, + * not every node will have attributes. + */ + std::unique_ptr > attr; /*! *\brief constructor *\param op the OperatorProperty to construct the Node *\param name the name of the symbol */ - explicit Node(OperatorProperty *op = nullptr, const std::string& name = "") - : op(op), name(name) { + explicit Node(OperatorProperty *op, + const std::string& name) + : op(op), name(name) {} + /*! + *\brief copy constructor constructor + */ + explicit Node(const Node& other) + : name(other.name) { + if (other.op != nullptr) { + op.reset(other.op->Copy()); + } + if (other.attr.get() != nullptr) { + attr.reset(new std::map(*(other.attr))); + } } /*! \return Whether the symbol is atomic */ inline bool is_atomic() const { @@ -129,11 +147,7 @@ Symbol Symbol::Copy() const { std::unordered_map > old_new; // use DFSVisit to copy all the nodes this->DFSVisit([&old_new](const std::shared_ptr &node) { - if (node->op == nullptr) { - old_new[node.get()] = std::make_shared(nullptr, node->name); - } else { - old_new[node.get()] = std::make_shared(node->op->Copy(), node->name); - } + old_new[node.get()] = std::make_shared(*node); }); // connect nodes of new graph for (const auto &kv : old_new) { @@ -310,6 +324,11 @@ void Symbol::Compose(const std::vector& args, for (size_t i = args.size(); i < req_args.size(); ++i) { heads_[0].source->inputs[i] = DataEntry( std::make_shared(nullptr, DefaultVarName(name, req_args[i])), 0); + // also copy attribute of operator over to automatically created variable + if (heads_[0].source->attr.get() != nullptr) { + heads_[0].source->inputs[i].source->attr.reset( + new std::map(*(heads_[0].source->attr))); + } } } else { // find all the place holders @@ -370,6 +389,11 @@ void Symbol::Compose(const std::unordered_map& kwargs, } else { heads_[0].source->inputs[i] = DataEntry( std::make_shared(nullptr, DefaultVarName(name, req_args[i])), 0); + // 
also copy attribute of operator over to automatically created variable + if (heads_[0].source->attr.get() != nullptr) { + heads_[0].source->inputs[i].source->attr.reset( + new std::map(*(heads_[0].source->attr))); + } } } // if things goes wrong recover the old state @@ -426,6 +450,31 @@ void Symbol::Compose(const std::unordered_map& kwargs, } } +void Symbol::SetAttr(const std::string &key, const std::string& value) { + Node* node = heads_[0].source.get(); + for (const DataEntry& e : heads_) { + CHECK(node == e.source.get()) + << "Symbol.SetAttr only works for non-grouped symbol"; + } + if (node->attr.get() == nullptr) { + node->attr.reset(new std::map()); + } + (*node->attr)[key] = value; +} + +bool Symbol::GetAttr(const std::string& key, std::string* out) { + Node* node = heads_[0].source.get(); + for (const DataEntry& e : heads_) { + CHECK(node == e.source.get()) + << "Symbol.GetAttr only works for non-grouped symbol"; + } + if (node->attr.get() == nullptr) return false; + auto it = node->attr->find(key); + if (it == node->attr->end()) return false; + *out = it->second; + return true; +} + Symbol Symbol::operator () (const std::vector& args, const std::string& name) const { Symbol s = this->Copy(); @@ -453,8 +502,7 @@ Symbol Symbol::Grad(const std::vector& wrt) const { }); for (std::vector::const_iterator it = g.nodes.begin() + num_nodes; it != g.nodes.end(); ++it) { - auto sym_node = std::make_shared(); - sym_node->name = it->name; + auto sym_node = std::make_shared(nullptr, it->name); if (it->backward_source_id != -1) { sym_node->backward_source_node = shared_node[it->backward_source_id]; } @@ -557,7 +605,6 @@ Symbol Symbol::CreateVariable(const std::string &name) { } void Symbol::ToStaticGraph(StaticGraph *out_graph) const { - // TODO(bing): Check unique name std::vector node_order; std::unordered_map node_index; auto &arg_nodes = out_graph->arg_nodes; @@ -586,6 +633,9 @@ void Symbol::ToStaticGraph(StaticGraph *out_graph) const { } else { 
out_graph->nodes[nid].backward_source_id = -1; } + if (node_order[nid]->attr.get() != nullptr) { + out_graph->nodes[nid].attr = *(node_order[nid]->attr); + } out_graph->nodes[nid].name = node_order[nid]->name; auto &inputs = out_graph->nodes[nid].inputs; inputs.clear(); @@ -612,14 +662,16 @@ void Symbol::FromStaticGraph(const StaticGraph &graph) { // copy ver nodes in topo order for (uint32_t nid : topo_order) { auto &gnode = graph.nodes[nid]; - auto sym_node = std::make_shared(); - sym_node->name = gnode.name; + auto sym_node = std::make_shared(nullptr, gnode.name); if (gnode.op.get() != nullptr) { sym_node->op.reset(gnode.op->Copy()); } if (gnode.backward_source_id != -1) { sym_node->backward_source_node = nodes.at(gnode.backward_source_id); } + if (gnode.attr.size() != 0) { + sym_node->attr.reset(new std::map(gnode.attr)); + } for (const StaticGraph::DataEntry& e : gnode.inputs) { Symbol::DataEntry entry(nodes.at(e.source_id), e.index); sym_node->inputs.push_back(std::move(entry)); diff --git a/tests/python/unittest/test_attr.py b/tests/python/unittest/test_attr.py new file mode 100644 index 000000000000..b17e8ec49759 --- /dev/null +++ b/tests/python/unittest/test_attr.py @@ -0,0 +1,33 @@ +import os +import mxnet as mx +from common import models +import pickle as pkl + +def test_attr_basic(): + with mx.AttrScope(group='4', data='great'): + data = mx.symbol.Variable('data', + attr={'dtype':'data', + 'group': '1'}) + gdata = mx.symbol.Variable('data2') + assert gdata.attr('group') == '4' + assert data.attr('group') == '1' + data2 = pkl.loads(pkl.dumps(data)) + assert data.attr('dtype') == data2.attr('dtype') + +def test_operator(): + data = mx.symbol.Variable('data') + with mx.AttrScope(group='4', data='great'): + fc1 = mx.symbol.Activation(data, act_type='relu') + with mx.AttrScope(init_bias='0.0'): + fc2 = mx.symbol.FullyConnected(fc1, num_hidden=10, name='fc2') + assert fc1.attr('data') == 'great' + fc2copy = pkl.loads(pkl.dumps(fc2)) + assert fc2copy.tojson() 
== fc2.tojson() + fc2weight = fc2.get_internals()['fc2_weight'] + + +if __name__ == '__main__': + test_attr_basic() + test_operator() + +