Merged
506 changes: 506 additions & 0 deletions nnvm/python/nnvm/to_relay.py

Large diffs are not rendered by default.
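Since the to_relay.py diff itself is not rendered, here is a minimal sketch of how the new converter is driven, inferred from the test added below (the resnet workload and variable names are illustrative assumptions, not taken from this diff):

import nnvm
from nnvm import testing, to_relay
import tvm
from tvm.relay import ir_pass, create_executor

# Illustrative workload; any nnvm.testing model should work the same way.
sym, params = testing.resnet.get_workload(batch_size=1)
graph = nnvm.graph.create(sym)
shapes = {"data": (1, 3, 224, 224)}
dtypes = {"data": "float32"}

# to_relay.to_relay(graph, shapes, dtypes, params) returns a Relay expression
# plus the parameter dict, which can then be type-checked and executed.
relay_expr, relay_params = to_relay.to_relay(graph, shapes, dtypes, params)
relay_expr = ir_pass.infer_type(relay_expr)
executor = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm')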

41 changes: 41 additions & 0 deletions nnvm/tests/python/compiler/test_to_relay.py
@@ -0,0 +1,41 @@
import nnvm
from nnvm import testing
from nnvm import to_relay
import tvm
from tvm.relay import ir_pass
from tvm.relay import create_executor
from tvm.contrib import graph_runtime
import numpy as np

def check_model(sym, shapes, dtypes, params):
    net = nnvm.graph.create(sym)
    graph_json, mod, params = nnvm.compiler.build(
        net,
        'llvm',
        shape=shapes,
        dtype=dtypes,
        params=params)
    nnvm_rts = graph_runtime.create(graph_json, mod, tvm.cpu(0))
    inputs = {}
    for name in shapes:
        np_array = np.random.rand(*shapes[name]).astype('float32')
        inputs[name] = tvm.nd.array(np_array)

    nnvm_rts.set_input(**params)
    nnvm_rts.run(**inputs)
    nnvm_out = nnvm_rts.get_output(0)
    relay_model, params = to_relay.to_relay(net, shapes, dtypes, params)
    relay_model = ir_pass.infer_type(relay_model)
    relay_rts = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm')
    inputs.update(params)
    relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values()))
    np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy())

def test_mlp():
    mlp, params = testing.mlp.get_workload(1)
    shapes = { "data": (10, 3, 224, 224) }
    dtypes = { "data": 'float32' }
    check_model(mlp, shapes, dtypes, params)

if __name__ == "__main__":
    test_mlp()
55 changes: 54 additions & 1 deletion python/tvm/relay/frontend/common.py
@@ -101,11 +101,64 @@ def get_int_tuple(self, key, default=RequiredAttr()):
"""
if key in self.attrs:
tshape = self.attrs[key]
return tuple(int(x.strip()) for x in tshape.strip('()').split(','))
return tuple(int(x.strip()) for x in tshape.strip('()[]').split(','))
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
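For context, a minimal sketch of what the relaxed strip set buys; the attribute dict below is made up for illustration, assuming StrAttrsDict is importable from tvm.relay.frontend.common:

from tvm.relay.frontend.common import StrAttrsDict

# get_int_tuple now also accepts bracketed, MXNet/NNVM-style list strings.
attrs = StrAttrsDict({"shape": "[1, 3, 224, 224]"})   # illustrative attribute dict
assert attrs.get_int_tuple("shape") == (1, 3, 224, 224)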

    def get_tuple_tuple_int(self, key, default=RequiredAttr()):
        """Get tuple of tuple of int attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : tuple of tuple of int
            The default value.

        Returns
        -------
        value : The result
        """
        if key in self.attrs:
            value = self.attrs[key]
            seq = []
            for tup in value.strip('()').split('),'):
                tup = tup.strip('[]()')
                els = [int(x.strip('( ')) for x in tup.split(',')]
                seq.append(tuple(els))

            return tuple(seq)

        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default
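As an illustration of the string format this parser expects; the attribute name and value are made up for the example, again assuming StrAttrsDict can be imported from tvm.relay.frontend.common:

from tvm.relay.frontend.common import StrAttrsDict

# Hypothetical "pad" attribute holding a tuple of per-dimension pairs.
attrs = StrAttrsDict({"pad": "((1, 1), (2, 2))"})
assert attrs.get_tuple_tuple_int("pad") == ((1, 1), (2, 2))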

    def get_int_list(self, key, default=RequiredAttr()):
        """Get int list attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : list of int
            The default value.

        Returns
        -------
        value : The result
        """
        if key in self.attrs:
            tshape = self.attrs[key]
            return tuple(int(x.strip()) for x in tshape.strip('[]()').split(','))
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default
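get_int_list accepts the same bracketed or parenthesised strings and, despite its name, returns a tuple; a made-up example:

attrs = StrAttrsDict({"axes": "[0, 2, 3]"})   # illustrative attribute dict
assert attrs.get_int_list("axes") == (0, 2, 3)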



    def get_bool(self, key, default=RequiredAttr()):
        """Get bool attribute

137 changes: 7 additions & 130 deletions python/tvm/relay/frontend/mxnet.py
@@ -8,138 +8,14 @@
from .. import op as _op
from ... import nd as _nd
from .common import StrAttrsDict
from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
from .nnvm_common import _clip, _transpose, _upsampling
from .nnvm_common import _elemwise_sum, _reshape
from .nnvm_common import _warn_not_used

__all__ = ['from_mxnet']


def _get_relay_op(op_name):
    op = getattr(_op, op_name)
    if not op:
        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
    return op


def _warn_not_used(attr, op='nnvm'):
    import warnings
    err = "{} is ignored in {}.".format(attr, op)
    warnings.warn(err)


def _rename(new_op):
    if isinstance(new_op, str):
        new_op = _get_relay_op(new_op)
    # attrs are ignored.
    def impl(inputs, _):
        return new_op(*inputs)
    return impl


def _reshape(inputs, attrs):
    if attrs.get_bool("reverse", False):
        raise RuntimeError("reshape do not support option reverse")
    shape = attrs.get_int_tuple("shape")
    return _op.reshape(inputs[0], newshape=shape)


def _init_op(new_op):
    """Init ops like zeros/ones"""
    def _impl(inputs, attrs):
        assert len(inputs) == 0
        shape = attrs.get_int_tuple("shape")
        dtype = attrs.get_str("dtype", "float32")
        return new_op(shape=shape, dtype=dtype)
    return _impl


def _softmax_op(new_op):
    """softmax/log_softmax"""
    def _impl(inputs, attrs):
        assert len(inputs) == 1
        axis = attrs.get_int("axis", -1)
        return new_op(inputs[0], axis=axis)
    return _impl


def _reduce(new_op):
    """Reduction ops like sum/min/max"""
    def _impl(inputs, attrs):
        assert len(inputs) == 1
        axis = attrs.get_int_tuple("axis", [])
        keepdims = attrs.get_bool("keepdims", False)
        # use None for reduce over all axis.
        axis = None if len(axis) == 0 else axis
        return new_op(inputs[0], axis=axis, keepdims=keepdims)
    return _impl


def _arg_reduce(new_op):
    """Arg Reduction ops like argmin/argmax"""
    def _impl(inputs, attrs):
        assert len(inputs) == 1
        axis = attrs.get_int("axis", None)
        keepdims = attrs.get_bool("keepdims", False)
        res = new_op(inputs[0], axis=[axis], keepdims=keepdims)
        # cast to dtype.
        res = res.astype("float32")
        return res
    return _impl


def _cast(inputs, attrs):
    """Type cast"""
    dtype = attrs.get_str("dtype")
    return _op.cast(inputs[0], dtype=dtype)


def _clip(inputs, attrs):
    a_min = attrs.get_float("a_min")
    a_max = attrs.get_float("a_max")
    return _op.clip(inputs[0], a_min=a_min, a_max=a_max)


def _transpose(inputs, attrs):
    axes = attrs.get_int_tuple("axes", None)
    # translate default case
    axes = None if len(axes) == 0 else axes
    return _op.transpose(inputs[0], axes=axes)


def _upsampling(inputs, attrs):
    scale = attrs.get_int("scale")
    return _op.nn.upsampling(inputs[0], scale=scale)


def _elemwise_sum(inputs, _):
    assert len(inputs) > 0
    res = inputs[0]
    for x in inputs[1:]:
        res = _op.add(res, x)
    return res


def _binop_scalar(new_op):
    def _impl(inputs, attrs):
        assert len(inputs) == 1
        scalar = attrs.get_float("scalar")
        # Note: binary scalar only works for float op for now
        scalar = _expr.const(scalar, dtype="float32")
        return new_op(inputs[0], scalar)
    return _impl


def _rbinop_scalar(new_op):
    def _impl(inputs, attrs):
        assert len(inputs) == 1
        scalar = attrs.get_float("scalar")
        # Note: binary scalar only works for float op for now
        scalar = _expr.const(scalar, dtype="float32")
        return new_op(scalar, inputs[0])
    return _impl

# All the functions with _mx prefix specific to MXNet.
# The functions without _mx prefix can be reused for
# NNVMv1 conversion to _op.
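For illustration of that convention, a hypothetical fragment of an op-conversion table; the dictionary name is made up, and the two entries simply mirror helpers that appear in this diff:

# Hypothetical table fragment illustrating the naming convention above.
_example_convert_map = {
    "relu":           _rename(_op.nn.relu),    # shared helper, reusable for NNVMv1
    "FullyConnected": _mx_fully_connected,     # MXNet-specific converter
}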

def _mx_fully_connected(inputs, attrs):
    import mxnet as mx
    units = attrs.get_int("num_hidden")
@@ -493,6 +369,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
jnodes = jgraph["nodes"]
node_map = {}


for nid, node in enumerate(jnodes):
children = [node_map[e[0]][e[1]] for e in node["inputs"]]
attrs = StrAttrsDict(node.get("attrs", {}))
@@ -501,7 +378,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
if op_name == "null":
shape = shape_dict[node_name] if node_name in shape_dict else None
if isinstance(dtype_info, dict):
dtype = dtype_info[node_name] if node_name in dtype_dict else "float32"
dtype = dtype_info[node_name] if node_name in dtype_info else "float32"
else:
dtype = dtype_info
node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]