From 3bfef6425aad7070e92009ab559a61ed4fbb2f40 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 24 Feb 2018 19:03:13 -0800 Subject: [PATCH 1/3] MXNet NDArray bridge. Support convert a tvm Function as MXNet's async NDArray function. --- python/tvm/contrib/mxnet.py | 57 +++++++++++++++++++++++ src/api/api_base.cc | 5 ++ tests/python/contrib/test_mxnet_bridge.py | 50 ++++++++++++++++++++ 3 files changed, 112 insertions(+) create mode 100644 python/tvm/contrib/mxnet.py create mode 100644 tests/python/contrib/test_mxnet_bridge.py diff --git a/python/tvm/contrib/mxnet.py b/python/tvm/contrib/mxnet.py new file mode 100644 index 000000000000..1c5da1571f55 --- /dev/null +++ b/python/tvm/contrib/mxnet.py @@ -0,0 +1,57 @@ +"""MXNet bridge wrap Function MXNet's async function.""" +from __future__ import absolute_import as _abs + +from .. import api, _api_internal, ndarray +from ..module import Module + +_wrap_async = None + + +def to_mxnet_func(func, const_loc=None): + """Wrap a TVM function as MXNet function + + MXNet function runs asynchrously via its engine. + + Parameters + ---------- + func : Function + A TVM function that can take positional arguments + + const_loc : list of int + List of integers indicating the argument position + of read only NDArray argument. + The NDArray argument location that are not annotated + will be viewed as mutable arrays in MXNet's engine. + + Returns + ------- + async_func : Function + A function that can take MXNet NDArray as argument + in places that used to expect TVM NDArray. + Run asynchrously in MXNet's async engine. + """ + # only import mxnet when wrap get called. + import mxnet + if isinstance(func, Module): + func = func.entry_func + + def _get_bridge_func(): + """Get MXNet bridge function""" + if not mxnet.base._LIB.MXTVMBridge: + raise RuntimeError( + "MXTVMBridge not exist in mxnet package," + " please update to latest version") + + fdict = api.extract_ext_funcs(mxnet.base._LIB.MXTVMBridge) + ret = fdict["WrapAsyncCall"] + ret.is_global = True + return ret + global _wrap_async + + if _wrap_async is None: + # Register extension type in first time + _wrap_async = _get_bridge_func() + ndarray.register_extension(mxnet.nd.NDArray) + + const_loc = const_loc if const_loc else [] + return _wrap_async(func, _api_internal._TVMSetStream, len(const_loc), *const_loc) diff --git a/src/api/api_base.cc b/src/api/api_base.cc index df8469903533..cc76f6a8f50b 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -36,4 +36,9 @@ TVM_REGISTER_API("_load_json") TVM_REGISTER_API("_nop") .set_body([](TVMArgs args, TVMRetValue *ret) { }); + +TVM_REGISTER_API("_TVMSetStream") +.set_body([](TVMArgs args, TVMRetValue *ret) { + TVMSetStream(args[0], args[1], args[2]); + }); } // namespace tvm diff --git a/tests/python/contrib/test_mxnet_bridge.py b/tests/python/contrib/test_mxnet_bridge.py new file mode 100644 index 000000000000..9c590bc3a40e --- /dev/null +++ b/tests/python/contrib/test_mxnet_bridge.py @@ -0,0 +1,50 @@ + + +def mxnet_check(): + """This is a simple test function for RPC Proxy + + It is not included as nosetests, because of its dependency on mxnet + + User can directly run this script to verify correctness. + """ + import mxnet as mx + import topi + import tvm + import numpy as np + from tvm.contrib.mxnet import to_mxnet_func + + # build a TVM function through topi + n = 20 + shape = (20,) + scale = tvm.var("scale", dtype="float32") + x = tvm.placeholder(shape) + y = tvm.placeholder(shape) + z = topi.broadcast_add(x, y) + zz = tvm.compute(shape, lambda *i: z(*i) * scale) + + target = tvm.target.cuda() + + # build the function + with target: + s = topi.generic.schedule_injective(zz) + f = tvm.build(s, [x, y, zz, scale]) + + # get a mxnet version + mxf = to_mxnet_func(f, const_loc=[0, 1]) + + ctx = mx.gpu(0) + xx = mx.nd.uniform(shape=shape, ctx=ctx) + yy = mx.nd.uniform(shape=shape, ctx=ctx) + zz = mx.nd.empty(shape=shape, ctx=ctx) + + # invoke myf: this runs in mxnet engine + mxf(xx, yy, zz, 10.0) + mxf(xx, yy, zz, 10.0) + + + np.testing.assert_allclose( + zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10) + + +if __name__ == "__main__": + mxnet_check() From 77f6f665b16b2631bd030e9133424fcb7b7fcb78 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 24 Feb 2018 21:24:24 -0800 Subject: [PATCH 2/3] fix lint --- python/tvm/contrib/mxnet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/contrib/mxnet.py b/python/tvm/contrib/mxnet.py index 1c5da1571f55..3a6c92f1b880 100644 --- a/python/tvm/contrib/mxnet.py +++ b/python/tvm/contrib/mxnet.py @@ -4,6 +4,7 @@ from .. import api, _api_internal, ndarray from ..module import Module +# pylint: disable=invalid-name _wrap_async = None @@ -31,6 +32,7 @@ def to_mxnet_func(func, const_loc=None): Run asynchrously in MXNet's async engine. """ # only import mxnet when wrap get called. + # pylint: disable=import-self import mxnet if isinstance(func, Module): func = func.entry_func From c1151420ac1376e2f4502564d8da5108089ae37d Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 24 Feb 2018 21:29:21 -0800 Subject: [PATCH 3/3] update comment --- tests/python/contrib/test_mxnet_bridge.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/contrib/test_mxnet_bridge.py b/tests/python/contrib/test_mxnet_bridge.py index 9c590bc3a40e..2228f7305c6b 100644 --- a/tests/python/contrib/test_mxnet_bridge.py +++ b/tests/python/contrib/test_mxnet_bridge.py @@ -1,7 +1,5 @@ - - def mxnet_check(): - """This is a simple test function for RPC Proxy + """This is a simple test function for MXNet bridge It is not included as nosetests, because of its dependency on mxnet