Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,37 @@ MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str);
/*!
* \brief Get string attribute from symbol
* \param symbol the source symbol
* \param key The key of the symbol.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int *success);
/*!
* \brief Set string attribute from symbol.
* NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph.
*
* Safe recommendaton: use immutable graph
* - Only allow set attributes during creation of new symbol as optional parameter
*
* Mutable graph (be careful about the semantics):
* - Allow set attr at any point.
* - Mutating an attribute of some common node of two graphs can cause confusion from user.
*
* \param symbol the source symbol
* \param key The key of the symbol.
* \param value The value to be saved.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolSetAttr(SymbolHandle symbol,
const char* key,
const char* value);
/*!
* \brief List arguments in the symbol.
* \param symbol the symbol
Expand Down
17 changes: 17 additions & 0 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,23 @@ class Symbol {
*/
void Compose(const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name);
/*!
* \brief set additional attributes of the symbol,
* This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised.
* \param key the key of the attribute
* \param value the value of the attribute.
*/
void SetAttr(const std::string &key, const std::string& value);
/*!
* \brief Get attributes from the symbol.
* This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised.
* \param key Key of the attribute.
* \param out the output value of the attribute.
* \return true if the attribute exists, false if the attribute do not exist.
*/
bool GetAttr(const std::string& key, std::string* out);
/*!
* \brief Apply the symbol as a function, compose with arguments
* \param args positional arguments for the symbol
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
# use mx.kv as short for kvstore
from . import kvstore as kv
from . import kvstore_server
# Runtime compile module
from .rtc import Rtc as rtc
# Attribute scope to add attributes to symbolic graphs
from .attribute import AttrScope


__version__ = base.__version__
62 changes: 62 additions & 0 deletions python/mxnet/attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# coding: utf-8
"""Attribute scoping support for symbolic API."""
from __future__ import absolute_import

from .base import string_types

class AttrScope(object):
"""Attribute manager for scoping.

User can also inheritate this object to change naming behavior.

Parameters
----------
kwargs
The attributes to set for all symbol creations in the scope.
"""
current = None

def __init__(self, **kwargs):
self._old_scope = None
for value in kwargs.values():
if not isinstance(value, string_types):
raise ValueError("Attributes need to be string")
self._attr = kwargs

def get(self, attr):
"""
Get the attribute dict given the attribute set by the symbol.

Parameters
----------
attr : dict of string to string
The attribute passed in by user during symbol creation.

Returns
-------
attr : dict of string to string
Updated attributes to add other scope related attributes.
"""
if self._attr:
ret = self._attr.copy()
if attr:
ret.update(attr)
return ret
else:
return attr

def __enter__(self):
# pylint: disable=protected-access
self._old_scope = AttrScope.current
attr = AttrScope.current._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope.current = self
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope.current = self._old_scope

AttrScope.current = AttrScope()

52 changes: 49 additions & 3 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .base import NDArrayHandle, ExecutorHandle, SymbolHandle
from .base import check_call, ctypes2docstring
from .name import NameManager
from .attribute import AttrScope
from .context import Context
from .ndarray import NDArray, zeros
from .executor import Executor
Expand Down Expand Up @@ -199,6 +200,42 @@ def __getitem__(self, index):
self.handle, mx_uint(index), ctypes.byref(handle)))
return Symbol(handle=handle)

def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.

Parameters
----------
key : str
The key to get attribute from.

Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.MXSymbolGetAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None

def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.

Parameters
----------
**kwargs
The attributes to set
"""
for key, value in kwargs.items():
if not isinstance(value, string_types):
raise ValueError("Set Attr only accepts string values")
check_call(_LIB.MXSymbolSetAttr(
self.handle, c_str(key), c_str(str(value))))

def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.

Expand Down Expand Up @@ -630,13 +667,15 @@ def grad(self, wrt):
# pylint: enable= no-member


def Variable(name):
def Variable(name, attr=None):
"""Create a symbolic variable with specified name.

Parameters
----------
name : str
Name of the variable.
attr : dict of string -> string
Additional attributes to set on the variable.

Returns
-------
Expand All @@ -647,7 +686,11 @@ def Variable(name):
raise TypeError('Expect a string for variable `name`')
handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
return Symbol(handle)
ret = Symbol(handle)
attr = AttrScope.current.get(attr)
if attr:
ret._set_attr(**attr)
return ret


def Group(symbols):
Expand Down Expand Up @@ -784,6 +827,7 @@ def creator(*args, **kwargs):
param_vals = []
symbol_kwargs = {}
name = kwargs.pop('name', None)
attr = kwargs.pop('attr', None)

if key_var_num_args and key_var_num_args not in kwargs:
param_keys.append(c_str(key_var_num_args))
Expand Down Expand Up @@ -813,8 +857,10 @@ def creator(*args, **kwargs):
raise ValueError('This function support variable length of Symbol arguments.\n' +
'Please pass all the input Symbols via positional arguments' +
' instead of keyword arguments.')

s = Symbol(sym_handle)
attr = AttrScope.current.get(attr)
if attr:
s._set_attr(**attr)
hint = func_name.lower()
name = NameManager.current.get(name, hint)
s._compose(*args, name=name, **symbol_kwargs)
Expand Down
26 changes: 26 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,32 @@ int MXSymbolPrint(SymbolHandle symbol, const char **out_str) {
API_END();
}

int MXSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int* success) {
Symbol *s = static_cast<Symbol*>(symbol);
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
if (s->GetAttr(key, &(ret->ret_str))) {
*out = (ret->ret_str).c_str();
*success = 1;
} else {
*out = nullptr;
*success = 0;
}
API_END();
}

int MXSymbolSetAttr(SymbolHandle symbol,
const char* key,
const char* value) {
Symbol *s = static_cast<Symbol*>(symbol);
API_BEGIN();
s->SetAttr(key, value);
API_END();
}

int MXSymbolListArguments(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array) {
Expand Down
3 changes: 3 additions & 0 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,12 @@ void StaticGraph::Node::Save(dmlc::JSONWriter *writer) const {
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("inputs", inputs);
writer->WriteObjectKeyValue("backward_source_id", backward_source_id);
if (attr.size() != 0) writer->WriteObjectKeyValue("attr", attr);
writer->EndObject();
}

void StaticGraph::Node::Load(dmlc::JSONReader *reader) {
attr.clear();
dmlc::JSONObjectReadHelper helper;
std::string op_type_str;
std::map<std::string, std::string> param;
Expand All @@ -335,6 +337,7 @@ void StaticGraph::Node::Load(dmlc::JSONReader *reader) {
helper.DeclareField("name", &name);
helper.DeclareField("inputs", &inputs);
helper.DeclareField("backward_source_id", &backward_source_id);
helper.DeclareOptionalField("attr", &attr);
helper.ReadAllFields(reader);

if (op_type_str != "null") {
Expand Down
26 changes: 14 additions & 12 deletions src/symbol/static_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <algorithm>
#include <utility>
#include <vector>
#include <map>

namespace mxnet {
/*!
Expand Down Expand Up @@ -109,23 +110,24 @@ class StaticGraph {
* When the node is a Backward node, the op field will be nullptr
*/
int32_t backward_source_id;
/*! \brief additional attributes about the node */
std::map<std::string, std::string> attr;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}

friend void swap(Node& lhs, Node& rhs) {
std::swap(lhs.op, rhs.op);
std::swap(lhs.name, rhs.name);
std::swap(lhs.inputs, rhs.inputs);
std::swap(lhs.backward_source_id, rhs.backward_source_id);
}
/*! \brief copy constructor in favor of serialization. */
Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr),
name(another.name),
inputs(another.inputs),
backward_source_id(another.backward_source_id) {}
Node(const Node& another)
: op(another.op.get() ? another.op.get()->Copy() : nullptr),
name(another.name),
inputs(another.inputs),
backward_source_id(another.backward_source_id),
attr(another.attr) {}

inline Node& operator=(Node another) {
swap(*this, another);
op = std::move(another.op);
name = std::move(another.name);
inputs = std::move(another.inputs);
backward_source_id = std::move(another.backward_source_id);
attr = std::move(another.attr);
return *this;
}
/*! \return whether the node is forward op node */
Expand Down
Loading