diff --git a/.gitignore b/.gitignore
index 50376c5fea53..122ddb4097f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -91,3 +91,4 @@ ENV/
 *.pyc
 *~
 build
+config.mk
diff --git a/Makefile b/Makefile
index 86368ab2038f..514ffa665491 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,12 @@
-export LDFLAGS = -pthread -lm
-export CFLAGS = -std=c++11 -Wall -O2\
- -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
+ifndef config
+ifneq ("$(wildcard ./config.mk)","")
+  config = config.mk
+else
+  config = make/config.mk
+endif
+endif
+
+include $(config)
 
 # specify tensor path
 .PHONY: clean all test doc
@@ -13,6 +19,30 @@ SRC = $(wildcard src/*.cc src/*/*.cc)
 ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
 ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
 
+ifneq ($(USE_CUDA_PATH), NONE)
+  NVCC=$(USE_CUDA_PATH)/bin/nvcc
+endif
+
+export LDFLAGS = -pthread -lm
+export CFLAGS = -std=c++11 -Wall -O2\
+ -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
+
+ifneq ($(ADD_CFLAGS), NONE)
+  CFLAGS += $(ADD_CFLAGS)
+endif
+
+ifneq ($(ADD_LDFLAGS), NONE)
+  LDFLAGS += $(ADD_LDFLAGS)
+endif
+
+
+ifeq ($(USE_CUDA), 1)
+  CFLAGS += -DTVM_CUDA_RUNTIME=1
+  LDFLAGS += -lcuda -lcudart
+else
+  CFLAGS += -DTVM_CUDA_RUNTIME=0
+endif
+
 include tests/cpp/unittest.mk
 
 test: $(TEST)
diff --git a/include/tvm/base.h b/include/tvm/base.h
index 47ed2861a6aa..c9c95f3c9bc2 100644
--- a/include/tvm/base.h
+++ b/include/tvm/base.h
@@ -17,6 +17,20 @@ namespace tvm {
+/*!
+ * \brief Whether to use the CUDA runtime.
+ */
+#ifndef TVM_CUDA_RUNTIME
+#define TVM_CUDA_RUNTIME 1
+#endif
+
+/*!
+ * \brief Whether to use the OpenCL runtime.
+ */
+#ifndef TVM_OPENCL_RUNTIME
+#define TVM_OPENCL_RUNTIME 0
+#endif
+
 using ::tvm::Node;
 using ::tvm::NodeRef;
 using ::tvm::AttrVisitor;
diff --git a/include/tvm/c_runtime_api.h b/include/tvm/c_runtime_api.h
index f99d83e60f2a..7be198b122ea 100644
--- a/include/tvm/c_runtime_api.h
+++ b/include/tvm/c_runtime_api.h
@@ -34,7 +34,7 @@ TVM_EXTERN_C {
 /*! \brief type of array index. */
-typedef unsigned tvm_index_t;
+typedef uint32_t tvm_index_t;
 
 /*!
  * \brief union type for arguments and return values
@@ -68,7 +68,7 @@ typedef enum {
   /*! \brief NVidia GPU device(CUDA) */
   kGPU = 2,
   /*! \brief opencl device */
-  KOpenCL = 4
+  kOpenCL = 4
 } TVMDeviceMask;
 
 /*!
@@ -79,7 +79,7 @@ typedef struct {
   int dev_mask;
   /*! \brief the device id */
   int dev_id;
-} TVMDevice;
+} TVMContext;
 
 /*! \brief The type code in TVMDataType */
 typedef enum {
@@ -122,8 +122,8 @@ typedef struct {
   tvm_index_t ndim;
   /*! \brief The data type flag */
   TVMDataType dtype;
-  /*! \brief The device this array sits on */
-  TVMDevice device;
+  /*! \brief The device context this array sits on */
+  TVMContext ctx;
 } TVMArray;
 
 /*!
@@ -150,6 +150,16 @@ typedef TVMArray* TVMArrayHandle;
  */
 TVM_DLL const char *TVMGetLastError(void);
 
+/*!
+ * \brief Whether the specified context is enabled.
+ *
+ * \param ctx The context to be checked.
+ * \param out_enabled Whether the context is enabled.
+ * \return Whether the function is successful.
+ */
+TVM_DLL int TVMContextEnabled(TVMContext ctx,
+                              int* out_enabled);
+
 /*!
  * \brief Allocate a nd-array's memory,
  *  including space of shape, of given spec.
 *
  * \param shape The shape of the array, the data content will be copied to out
  * \param ndim The number of dimension of the array.
  * \param dtype The array data type.
- * \param device The device this array sits on.
+ * \param ctx The context this array sits on.
  * \param out The output handle.
  * \return Whether the function is successful.
  */
 TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
                           tvm_index_t ndim,
-                          int dtype,
-                          TVMDevice device,
+                          TVMDataType dtype,
+                          TVMContext ctx,
                           TVMArrayHandle* out);
 /*!
  * \brief Free the TVM Array.
@@ -183,9 +193,10 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
                                TVMStreamHandle stream);
 /*!
  * \brief Wait until all computations on stream completes.
- * \param stream the stream to be synchronized.
+ * \param ctx The context to be synchronized.
+ * \param stream The stream to be synchronized.
  */
-TVM_DLL int TVMSynchronize(TVMStreamHandle stream);
+TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
 
 /*!
  * \brief Launch a generated TVM function
diff --git a/make/config.mk b/make/config.mk
new file mode 100644
index 000000000000..955bc6c8cb3c
--- /dev/null
+++ b/make/config.mk
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------------------------
+# Template configuration for compiling
+#
+# If you want to change the configuration, please use the following
+# steps. Assume you are in the root directory. First copy this
+# file so that any local changes will be ignored by git
+#
+# $ cp make/config.mk .
+#
+# Next, modify the entries as needed, and then compile by
+#
+# $ make
+#
+# or build in parallel with 8 threads
+#
+# $ make -j8
+#-------------------------------------------------------------------------------
+
+#---------------------
+# choice of compiler
+#---------------------
+export NVCC = nvcc
+
+# whether to compile with debug
+DEBUG = 0
+
+# the additional link flags you want to add
+ADD_LDFLAGS =
+
+# the additional compile flags you want to add
+ADD_CFLAGS =
+
+#---------------------------------------------
+# matrix computation libraries for CPU/GPU
+#---------------------------------------------
+
+# whether to use CUDA during compilation
+USE_CUDA = 1
+
+# add the path to the CUDA library to the link and compile flags
+# if you have already added them to the environment variables, leave it as NONE
+# USE_CUDA_PATH = /usr/local/cuda
+USE_CUDA_PATH = NONE
+
+# whether to use CUDA runtime compilation for writing kernels in a native language (i.e. Python)
+USE_NVRTC = 0
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 3725cd62c15d..00729bcf6a85 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -3,7 +3,7 @@
 from __future__ import absolute_import as _abs
 
 from ._ctypes._api import register_node
-from . import tensor as tensor
+from . import tensor
 from . import expr
 from . import stmt
 from . import make
@@ -11,5 +11,8 @@ from . import collections
 from . import schedule
 
+from . import ndarray as nd
+from .ndarray import cpu, gpu, opencl
+
 from ._base import TVMError
 from .function import *
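
The `__init__.py` change above exposes the new runtime helpers directly under the `tvm` namespace. A minimal usage sketch, assuming the package from this diff is importable as `tvm` (the `enabled` property added later in `_runtime_api.py` maps to `TVMContextEnabled`, so backends compiled out via `TVM_CUDA_RUNTIME`/`TVM_OPENCL_RUNTIME` simply report as disabled):

    import tvm

    # Each helper returns a TVMContext(dev_mask, dev_id).
    for ctx in [tvm.cpu(0), tvm.gpu(0), tvm.opencl(0)]:
        # `enabled` queries TVMContextEnabled in the shared library.
        print(ctx, "is enabled" if ctx.enabled else "is disabled")
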
diff --git a/python/tvm/_base.py b/python/tvm/_base.py
index cfa85a55ac1c..33c9848f2d1f 100644
--- a/python/tvm/_base.py
+++ b/python/tvm/_base.py
@@ -58,6 +58,7 @@ def check_call(ret):
     if ret != 0:
         raise TVMError(py_str(_LIB.TVMGetLastError()))
 
+
 def c_str(string):
     """Create ctypes char * from a python string
 
     Parameters
     ----------
@@ -72,6 +73,26 @@ def c_str(string):
     """
     return ctypes.c_char_p(string.encode('utf-8'))
 
+
+def c_array(ctype, values):
+    """Create ctypes array from a python array
+
+    Parameters
+    ----------
+    ctype : ctypes data type
+        data type of the array we want to convert to
+
+    values : tuple or list
+        data content
+
+    Returns
+    -------
+    out : ctypes array
+        Created ctypes array
+    """
+    return (ctype * len(values))(*values)
+
+
 def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
     """Convert ctypes returned doc string information into parameters docstring.
diff --git a/python/tvm/_ctypes/_runtime_api.py b/python/tvm/_ctypes/_runtime_api.py
new file mode 100644
index 000000000000..d12f727f6f8a
--- /dev/null
+++ b/python/tvm/_ctypes/_runtime_api.py
@@ -0,0 +1,294 @@
+# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
+# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
+"""Runtime NDArray and device context API."""
+from __future__ import absolute_import as _abs
+
+import ctypes
+import numpy as np
+
+from .._base import _LIB
+from .._base import c_array
+from .._base import check_call
+
+
+tvm_index_t = ctypes.c_uint32
+
+
+class TVMContext(ctypes.Structure):
+    """TVM context structure."""
+    _fields_ = [("dev_mask", ctypes.c_int),
+                ("dev_id", ctypes.c_int)]
+    MASK2STR = {
+        1 : 'cpu',
+        2 : 'gpu',
+        4 : 'opencl'
+    }
+    def __init__(self, dev_mask, dev_id):
+        super(TVMContext, self).__init__()
+        self.dev_mask = dev_mask
+        self.dev_id = dev_id
+
+    def __repr__(self):
+        return "%s(%d)" % (
+            TVMContext.MASK2STR[self.dev_mask], self.dev_id)
+
+    @property
+    def enabled(self):
+        ret = ctypes.c_int()
+        check_call(_LIB.TVMContextEnabled(self, ctypes.byref(ret)))
+        return ret.value != 0
+
+
+def cpu(dev_id=0):
+    """Construct a CPU device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+    """
+    return TVMContext(1, dev_id)
+
+
+def gpu(dev_id=0):
+    """Construct a GPU device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+    """
+    return TVMContext(2, dev_id)
+
+
+def opencl(dev_id=0):
+    """Construct an OpenCL device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+    """
+    return TVMContext(4, dev_id)
+
+
+class TVMDataType(ctypes.Structure):
+    """TVM datatype structure"""
+    _fields_ = [("type_code", ctypes.c_uint8),
+                ("bits", ctypes.c_uint8),
+                ("lanes", ctypes.c_uint16)]
+    CODE2STR = {
+        0 : 'int',
+        1 : 'uint',
+        2 : 'float'
+    }
+    def __init__(self, type_str, lanes=1):
+        super(TVMDataType, self).__init__()
+        if isinstance(type_str, np.dtype):
+            type_str = str(type_str)
+
+        if type_str.startswith("int"):
+            self.type_code = 0
+            bits = int(type_str[3:])
+        elif type_str.startswith("uint"):
+            self.type_code = 1
+            bits = int(type_str[4:])
+        elif type_str.startswith("float"):
+            self.type_code = 2
+            bits = int(type_str[5:])
+        else:
+            raise ValueError("Do not know how to handle type %s" % type_str)
+        bits = 32 if bits == 0 else bits
+        if (bits & (bits - 1)) != 0 or bits < 8:
+            raise ValueError("Do not know how to handle type %s" % type_str)
+        self.bits = bits
+        self.lanes = lanes
+
+    def __repr__(self):
+        x = "%s%d" % (TVMDataType.CODE2STR[self.type_code], self.bits)
+        if self.lanes != 1:
+            x += "x%d" % self.lanes
+        return x
+
+
+class TVMArray(ctypes.Structure):
+    """TVMArray in C API"""
+    _fields_ = [("data", ctypes.c_void_p),
+                ("shape", ctypes.POINTER(tvm_index_t)),
+                ("strides", ctypes.POINTER(tvm_index_t)),
+                ("ndim", tvm_index_t),
+                ("dtype", TVMDataType),
+                ("ctx", TVMContext)]
+
+TVMArrayHandle = ctypes.POINTER(TVMArray)
+
+
+def numpyasarray(np_data):
+    """Return a TVMArray representation of a numpy array.
+    """
+    data = np_data
+    assert data.flags['C_CONTIGUOUS']
+    arr = TVMArray()
+    shape = c_array(tvm_index_t, data.shape)
+    arr.data = data.ctypes.data_as(ctypes.c_void_p)
+    arr.shape = shape
+    arr.strides = None
+    arr.dtype = TVMDataType(np.dtype(data.dtype).name)
+    arr.ndim = data.ndim
+    # CPU device
+    arr.ctx = cpu(0)
+    return arr, shape
+
+
+_ndarray_cls = None
+
+
+def empty(shape, dtype="float32", ctx=cpu(0)):
+    """Create an empty array given shape and device
+
+    Parameters
+    ----------
+    shape : tuple of int
+        The shape of the array
+
+    dtype : type or str
+        The data type of the array.
+
+    ctx : TVMContext
+        The context of the array
+
+    Returns
+    -------
+    arr : tvm.nd.NDArray
+        The TVM-supported array.
+    """
+    shape = c_array(tvm_index_t, shape)
+    ndim = tvm_index_t(len(shape))
+    handle = TVMArrayHandle()
+    dtype = TVMDataType(dtype)
+    check_call(_LIB.TVMArrayAlloc(
+        shape, ndim, dtype, ctx, ctypes.byref(handle)))
+    return _ndarray_cls(handle)
+
+
+def sync(ctx):
+    """Synchronize all computations in the given context
+
+    Parameters
+    ----------
+    ctx : TVMContext
+        The context to be synced
+    """
+    check_call(_LIB.TVMSynchronize(ctx, None))
+
+
+class NDArrayBase(object):
+    """A simple Device/CPU Array object in runtime."""
+    __slots__ = ["handle"]
+    # pylint: disable=no-member
+    def __init__(self, handle):
+        """Initialize the array with a handle
+
+        Parameters
+        ----------
+        handle : TVMArrayHandle
+            the handle to the underlying C++ TVMArray
+        """
+        self.handle = handle
+
+    def __del__(self):
+        check_call(_LIB.TVMArrayFree(self.handle))
+
+    @property
+    def shape(self):
+        """Shape of this array"""
+        return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
+
+    @property
+    def dtype(self):
+        """Type of this array"""
+        return str(self.handle.contents.dtype)
+
+    @property
+    def ctx(self):
+        """context of this array"""
+        return self.handle.contents.ctx
+
+    @property
+    def context(self):
+        """context of this array"""
+        return self.ctx
+
+    def __setitem__(self, in_slice, value):
+        """Set ndarray value"""
+        if (not isinstance(in_slice, slice) or
+                in_slice.start is not None
+                or in_slice.stop is not None):
+            raise ValueError('Array only supports assignment from a numpy array')
+        if isinstance(value, NDArrayBase):
+            if value.handle is not self.handle:
+                value.copyto(self)
+        elif isinstance(value, (np.ndarray, np.generic)):
+            self._sync_copyfrom(value)
+        else:
+            raise TypeError('type %s not supported' % str(type(value)))
+
+    def _sync_copyfrom(self, source_array):
+        """Perform a synchronous copy from the array.
+
+        Parameters
+        ----------
+        source_array : array_like
+            The data source we would like to copy from.
+        """
+        if not isinstance(source_array, np.ndarray):
+            try:
+                source_array = np.array(source_array, dtype=self.dtype)
+            except:
+                raise TypeError('array must be an array_like data, ' +
+                                'type %s is not supported' % str(type(source_array)))
+        source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
+        if source_array.shape != self.shape:
+            raise ValueError('array shape does not match the shape of NDArray')
+        source_tvm_arr, shape = numpyasarray(source_array)
+        check_call(_LIB.TVMArrayCopyFromTo(
+            ctypes.byref(source_tvm_arr), self.handle, None))
+        # keep `shape` referenced until the copy call above has returned
+        _ = shape
+
+    def asnumpy(self):
+        """Convert this array to numpy array
+
+        Returns
+        -------
+        np_arr : numpy.ndarray
+            The corresponding numpy array.
+        """
+        np_arr = np.empty(self.shape, dtype=self.dtype)
+        tvm_arr, shape = numpyasarray(np_arr)
+        check_call(_LIB.TVMArrayCopyFromTo(
+            self.handle, ctypes.byref(tvm_arr), None))
+        _ = shape
+        return np_arr
+
+    def copyto(self, target):
+        """Copy array to target
+
+        Parameters
+        ----------
+        target : tvm.NDArray or TVMContext
+            The target array to be copied to, which must have the same shape
+            as this array, or the context on which to allocate the copy.
+        """
+        if isinstance(target, TVMContext):
+            target = empty(self.shape, self.dtype, target)
+        if isinstance(target, NDArrayBase):
+            check_call(_LIB.TVMArrayCopyFromTo(
+                self.handle, target.handle, None))
+        else:
+            raise ValueError("Unsupported target type %s" % str(type(target)))
+        return target
+
+
+def _init_runtime_module(ndarray_class):
+    global _ndarray_cls
+    _ndarray_cls = ndarray_class
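
For reference, a small sketch of how the `TVMDataType` constructor above parses dtype strings; this is illustration only and mirrors the class defined in `_runtime_api.py` (the `xN` suffix in the repr only appears when `lanes != 1`):

    import numpy as np
    from tvm._ctypes._runtime_api import TVMDataType

    print(TVMDataType("float32"))          # float32  (type_code=2, bits=32)
    print(TVMDataType(np.dtype("int8")))   # int8     (numpy dtypes are accepted too)
    print(TVMDataType("uint16", lanes=4))  # uint16x4
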
+ """ + if not isinstance(source_array, np.ndarray): + try: + source_array = np.array(source_array, dtype=self.dtype) + except: + raise TypeError('array must be an array_like data,' + + 'type %s is not supported' % str(type(source_array))) + source_array = np.ascontiguousarray(source_array, dtype=self.dtype) + if source_array.shape != self.shape: + raise ValueError('array shape do not match the shape of NDArray') + source_tvm_arr, shape = numpyasarray(source_array) + check_call(_LIB.TVMArrayCopyFromTo( + ctypes.byref(source_tvm_arr), self.handle, None)) + # de-allocate shape until now + _ = shape + + def asnumpy(self): + """Convert this array to numpy array + + Returns + ------- + np_arr : numpy.ndarray + The corresponding numpy array. + """ + np_arr = np.empty(self.shape, dtype=self.dtype) + tvm_arr, shape = numpyasarray(np_arr) + check_call(_LIB.TVMArrayCopyFromTo( + self.handle, ctypes.byref(tvm_arr), None)) + _ = shape + return np_arr + + def copyto(self, target): + """Copy array to target + + Parameters + ---------- + target : tvm.NDArray + The target array to be copied, must have same shape as this array. + """ + if isinstance(target, TVMContext): + target = empty(self.shape, self.dtype, target) + if isinstance(target, NDArrayBase): + check_call(_LIB.TVMArrayCopyFromTo( + self.handle, target.handle, None)) + else: + raise ValueError("Unsupported target type %s" % str(type(target))) + return target + + +def _init_runtime_module(ndarray_class): + global _ndarray_cls + _ndarray_cls = ndarray_class diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py new file mode 100644 index 000000000000..fc74c28fde20 --- /dev/null +++ b/python/tvm/ndarray.py @@ -0,0 +1,51 @@ +"""TVM Runtime API. + +This is a simplified runtime API for quick testing and proptyping. +""" +# pylint: disable=unused-import +from __future__ import absolute_import as _abs +import numpy as _np + +from ._ctypes._runtime_api import TVMContext, TVMDataType, NDArrayBase +from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync +from ._ctypes._runtime_api import _init_runtime_module + + +class NDArray(NDArrayBase): + """Lightweight NDArray class of TVM runtime. + + Strictly this is only an Array Container(a buffer object) + No arthimetic operations are defined. + All operations are performed by TVM functions. + + The goal is not to re-build yet another array library. + Instead, this is a minimal data structure to demonstrate + how can we use TVM in existing project which might have their own array containers. + """ + pass + + +def array(arr, ctx=cpu(0)): + """Create an array from source arr. + + Parameters + ---------- + arr : numpy.ndarray + The array to be copied from + + ctx : TVMContext + The device context to create the array + + Returns + ------- + ret : tvm.nd.NDArray + The created array + """ + if not isinstance(arr, _np.ndarray): + arr = _np.array(arr) + ret = empty(arr.shape, arr.dtype, ctx) + ret[:] = arr + return ret + + +_init_runtime_module(NDArray) diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index c5ec453b9c26..55b3fe0bd377 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -14,6 +14,6 @@ #include #include #include "./c_api_registry.h" -#include "../runtime/runtime_common.h" +#include "../runtime/runtime_base.h" #endif // TVM_C_API_C_API_COMMON_H_ diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc new file mode 100644 index 000000000000..efa355a3b3a1 --- /dev/null +++ b/src/runtime/c_runtime_api.cc @@ -0,0 +1,151 @@ +/*! 
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index c5ec453b9c26..55b3fe0bd377 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -14,6 +14,6 @@
 #include
 #include
 #include "./c_api_registry.h"
-#include "../runtime/runtime_common.h"
+#include "../runtime/runtime_base.h"
 
 #endif  // TVM_C_API_C_API_COMMON_H_
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
new file mode 100644
index 000000000000..efa355a3b3a1
--- /dev/null
+++ b/src/runtime/c_runtime_api.cc
@@ -0,0 +1,151 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file c_runtime_api.cc
+ * \brief Device specific implementations
+ */
+#include
+#include
+#include "./runtime_base.h"
+#include "./device_api.h"
+
+namespace tvm {
+namespace runtime {
+
+inline TVMArray* TVMArrayCreate_() {
+  TVMArray* arr = new TVMArray();
+  arr->shape = nullptr;
+  arr->strides = nullptr;
+  arr->ndim = 0;
+  arr->data = nullptr;
+  return arr;
+}
+
+inline void TVMArrayFree_(TVMArray* arr) {
+  if (arr != nullptr) {
+    // ok to delete nullptr
+    delete[] arr->shape;
+    delete[] arr->strides;
+    if (arr->data != nullptr) {
+      TVM_DEVICE_SWITCH(arr->ctx, {
+        FreeDataSpace<xpu>(arr->ctx, arr->data);
+      });
+    }
+  }
+  delete arr;
+}
+
+inline void VerifyType(TVMDataType dtype) {
+  CHECK_GE(dtype.lanes, 1U);
+  if (dtype.type_code == kFloat) {
+    CHECK_EQ(dtype.bits % 32U, 0U);
+  } else {
+    CHECK_EQ(dtype.bits % 8U, 0U);
+  }
+  CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
+}
+
+inline size_t GetDataSize(TVMArray* arr) {
+  size_t size = 1;
+  for (tvm_index_t i = 0; i < arr->ndim; ++i) {
+    size *= arr->shape[i];
+  }
+  size *= (arr->dtype.bits / 8) * arr->dtype.lanes;
+  return size;
+}
+
+inline size_t GetDataAlignment(TVMArray* arr) {
+  size_t align = (arr->dtype.bits / 8) * arr->dtype.lanes;
+  if (align < 8) return 8;
+  return align;
+}
+
+}  // namespace runtime
+}  // namespace tvm
+
+using namespace tvm::runtime;
+
+int TVMContextEnabled(TVMContext ctx,
+                      int* out_enabled) {
+  API_BEGIN();
+  if (ctx.dev_mask == kGPU && TVM_CUDA_RUNTIME == 0) {
+    *out_enabled = 0;
+  } else if (ctx.dev_mask == kOpenCL && TVM_OPENCL_RUNTIME == 0) {
+    *out_enabled = 0;
+  } else {
+    TVM_DEVICE_SWITCH(ctx, {
+      *out_enabled = CheckEnabled<xpu>(ctx);
+    });
+  }
+  API_END();
+}
+
+int TVMArrayAlloc(const tvm_index_t* shape,
+                  tvm_index_t ndim,
+                  TVMDataType dtype,
+                  TVMContext ctx,
+                  TVMArrayHandle* out) {
+  TVMArray* arr = nullptr;
+  API_BEGIN();
+  // shape
+  arr = TVMArrayCreate_();
+  // ndim
+  arr->ndim = ndim;
+  // dtype
+  VerifyType(dtype);
+  arr->dtype = dtype;
+  tvm_index_t* shape_copy = new tvm_index_t[ndim];
+  std::copy(shape, shape + ndim, shape_copy);
+  arr->shape = shape_copy;
+  // ctx
+  arr->ctx = ctx;
+  size_t size = GetDataSize(arr);
+  size_t alignment = GetDataAlignment(arr);
+  // ctx data pointer
+  TVM_DEVICE_SWITCH(ctx, {
+    arr->data = AllocDataSpace<xpu>(ctx, size, alignment);
+  });
+  *out = arr;
+  API_END_HANDLE_ERROR(TVMArrayFree_(arr));
+}
+
+int TVMArrayFree(TVMArrayHandle handle) {
+  API_BEGIN();
+  TVMArray* arr = handle;
+  TVMArrayFree_(arr);
+  API_END();
+}
+
+int TVMArrayCopyFromTo(TVMArrayHandle from,
+                       TVMArrayHandle to,
+                       TVMStreamHandle stream) {
+  API_BEGIN();
+  size_t from_size = GetDataSize(from);
+  size_t to_size = GetDataSize(to);
+  CHECK_EQ(from_size, to_size)
+      << "TVMArrayCopyFromTo: The size must exactly match";
+  TVMContext ctx = from->ctx;
+  if (ctx.dev_mask == kCPU) {
+    ctx = to->ctx;
+  } else {
+    CHECK(to->ctx.dev_mask == kCPU ||
+          to->ctx.dev_mask == from->ctx.dev_mask)
+        << "Cannot copy across different ctx types directly";
+  }
+
+  TVM_DEVICE_SWITCH(ctx, {
+    CopyDataFromTo<xpu>(from->data, to->data,
+                        from_size,
+                        from->ctx,
+                        to->ctx,
+                        stream);
+  });
+  API_END();
+}
+
+int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
+  API_BEGIN();
+  TVM_DEVICE_SWITCH(ctx, {
+    StreamSync<xpu>(ctx, stream);
+  });
+  API_END();
+}
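
Note the check in `TVMArrayCopyFromTo` above: unless one side lives on the CPU, both arrays must share the same device mask, so copies between two different device backends have to be staged through the host. A hedged sketch of what that looks like from Python, for a hypothetical build with both CUDA and OpenCL runtimes enabled:

    import tvm

    if tvm.gpu(0).enabled and tvm.opencl(0).enabled:
        a = tvm.nd.array([1.0, 2.0, 3.0], ctx=tvm.gpu(0))
        # a.copyto(tvm.opencl(0)) would trip the
        # "Cannot copy across different ctx types directly" check,
        # so go through a CPU array first.
        host = a.copyto(tvm.cpu(0))
        b = host.copyto(tvm.opencl(0))
        tvm.nd.sync(tvm.opencl(0))  # wait for outstanding copies
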
diff --git a/src/runtime/device_api.h b/src/runtime/device_api.h
new file mode 100644
index 000000000000..b74b41ae245b
--- /dev/null
+++ b/src/runtime/device_api.h
@@ -0,0 +1,99 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file device_api.h
+ * \brief Device specific API
+ */
+#ifndef TVM_RUNTIME_DEVICE_API_H_
+#define TVM_RUNTIME_DEVICE_API_H_
+
+#include
+#include
+
+namespace tvm {
+namespace runtime {
+/*!
+ * \brief Whether ctx is enabled.
+ * \param ctx The device context to perform operation.
+ * \tparam xpu The device mask.
+ */
+template<TVMDeviceMask xpu>
+inline bool CheckEnabled(TVMContext ctx) {
+  return true;
+}
+
+/*!
+ * \brief Allocate a data space on device.
+ * \param ctx The device context to perform operation.
+ * \param size The size of the memory
+ * \param alignment The alignment of the memory.
+ * \return The allocated device pointer
+ * \tparam xpu The device mask.
+ */
+template<TVMDeviceMask xpu>
+inline void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment);
+
+/*!
+ * \brief Free a data space on device.
+ * \param ctx The device context to perform operation.
+ * \param ptr The data space.
+ * \tparam xpu The device mask.
+ */
+template<TVMDeviceMask xpu>
+inline void FreeDataSpace(TVMContext ctx, void* ptr);
+
+/*!
+ * \brief copy data from one place to another
+ * \param from The source array.
+ * \param to The target array.
+ * \param size The size of the memory
+ * \param ctx_from The source context
+ * \param ctx_to The target context
+ * \param stream The stream to perform the copy on, if any.
+ * \tparam xpu The device mask.
+ */
+template<TVMDeviceMask xpu>
+inline void CopyDataFromTo(const void* from,
+                           void* to,
+                           size_t size,
+                           TVMContext ctx_from,
+                           TVMContext ctx_to,
+                           TVMStreamHandle stream);
+/*!
+ * \brief Synchronize the stream
+ * \param ctx The context to perform operation.
+ * \param stream The stream to be sync.
+ * \tparam xpu The device mask.
+ */
+template<TVMDeviceMask xpu>
+inline void StreamSync(TVMContext ctx, TVMStreamHandle stream);
+
+// macro to run cuda related code
+#if TVM_CUDA_RUNTIME
+#define TVM_RUN_CUDA(OP) { const TVMDeviceMask xpu = kGPU; OP; }
+#else
+#define TVM_RUN_CUDA(OP) LOG(FATAL) << "CUDA is not enabled";
+#endif
+
+// macro to run opencl related code
+#if TVM_OPENCL_RUNTIME
+#define TVM_RUN_OPENCL(OP) { const TVMDeviceMask xpu = kOpenCL; OP; }
+#else
+#define TVM_RUN_OPENCL(OP) LOG(FATAL) << "OpenCL is not enabled";
+#endif
+
+// macro to switch options between devices
+#define TVM_DEVICE_SWITCH(ctx, OP)                                  \
+  switch (ctx.dev_mask) {                                           \
+    case kCPU: { const TVMDeviceMask xpu = kCPU; OP; break; }       \
+    case kGPU: TVM_RUN_CUDA(OP); break;                             \
+    case kOpenCL: TVM_RUN_OPENCL(OP); break;                        \
+    default: LOG(FATAL) << "unknown device_mask " << ctx.dev_mask;  \
+  }
+
+}  // namespace runtime
+}  // namespace tvm
+
+#include "./device_api_gpu.h"
+#include "./device_api_cpu.h"
+
+#endif  // TVM_RUNTIME_DEVICE_API_H_
diff --git a/src/runtime/device_api_cpu.h b/src/runtime/device_api_cpu.h
new file mode 100644
index 000000000000..6fff5ded7f6e
--- /dev/null
+++ b/src/runtime/device_api_cpu.h
@@ -0,0 +1,54 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file device_api_cpu.h
+ * \brief CPU specific API
+ */
+#ifndef TVM_RUNTIME_DEVICE_API_CPU_H_
+#define TVM_RUNTIME_DEVICE_API_CPU_H_
+
+#include
+#include
+#include
+#include "./device_api.h"
+
+namespace tvm {
+namespace runtime {
+
+template<>
+void* AllocDataSpace<kCPU>(TVMContext ctx, size_t size, size_t alignment) {
+  void* ptr;
+#if _MSC_VER
+  ptr = _aligned_malloc(size, alignment);
+  if (ptr == nullptr) throw std::bad_alloc();
+#else
+  int ret = posix_memalign(&ptr, alignment, size);
+  if (ret != 0) throw std::bad_alloc();
+#endif
+  return ptr;
+}
+
+template<>
+void FreeDataSpace<kCPU>(TVMContext ctx, void* ptr) {
+#if _MSC_VER
+  _aligned_free(ptr);
+#else
+  free(ptr);
+#endif
+}
+
+template<>
+void CopyDataFromTo<kCPU>(const void* from,
+                          void* to,
+                          size_t size,
+                          TVMContext ctx_from,
+                          TVMContext ctx_to,
+                          TVMStreamHandle stream) {
+  memcpy(to, from, size);
+}
+
+template<>
+void StreamSync<kCPU>(TVMContext ctx, TVMStreamHandle stream) {
+}
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_DEVICE_API_CPU_H_
diff --git a/src/runtime/device_api_gpu.h b/src/runtime/device_api_gpu.h
new file mode 100644
index 000000000000..970450657b2e
--- /dev/null
+++ b/src/runtime/device_api_gpu.h
@@ -0,0 +1,106 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file device_api_gpu.h
+ * \brief GPU specific API
+ */
+#ifndef TVM_RUNTIME_DEVICE_API_GPU_H_
+#define TVM_RUNTIME_DEVICE_API_GPU_H_
+
+#include
+#include "./device_api.h"
+
+#if TVM_CUDA_RUNTIME
+#include
+
+namespace tvm {
+namespace runtime {
+/*!
+ * \brief Check CUDA error.
+ * \param msg Message to print if an error occurred.
+ */
+#define CHECK_CUDA_ERROR(msg)                                                \
+  {                                                                          \
+    cudaError_t e = cudaGetLastError();                                      \
+    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
+  }
+
+/*!
+ * \brief Protected CUDA call.
+ * \param func Expression to call.
+ *
+ * It checks for CUDA errors after invocation of the expression.
+ */
+#define CUDA_CALL(func)                                       \
+  {                                                           \
+    cudaError_t e = (func);                                   \
+    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)  \
+        << "CUDA: " << cudaGetErrorString(e);                 \
+  }
+
+template<>
+inline void* AllocDataSpace<kGPU>(TVMContext ctx, size_t size, size_t alignment) {
+  CUDA_CALL(cudaSetDevice(ctx.dev_id));
+  CHECK_EQ(256 % alignment, 0U)
+      << "CUDA space is aligned at 256 bytes";
+  void *ret;
+  CUDA_CALL(cudaMalloc(&ret, size));
+  return ret;
+}
+
+template<>
+inline void FreeDataSpace<kGPU>(TVMContext ctx, void* ptr) {
+  CUDA_CALL(cudaSetDevice(ctx.dev_id));
+  CUDA_CALL(cudaFree(ptr));
+}
+
+inline void GPUCopy(const void* from,
+                    void* to,
+                    size_t size,
+                    cudaMemcpyKind kind,
+                    cudaStream_t stream) {
+  if (stream != 0) {
+    CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
+  } else {
+    CUDA_CALL(cudaMemcpy(to, from, size, kind));
+  }
+}
+
+template<>
+inline void CopyDataFromTo<kGPU>(const void* from,
+                                 void* to,
+                                 size_t size,
+                                 TVMContext ctx_from,
+                                 TVMContext ctx_to,
+                                 TVMStreamHandle stream) {
+  cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
+  if (ctx_from.dev_mask == kGPU && ctx_to.dev_mask == kGPU) {
+    CUDA_CALL(cudaSetDevice(ctx_from.dev_id));
+    if (ctx_from.dev_id == ctx_to.dev_id) {
+      GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
+    } else {
+      cudaMemcpyPeerAsync(to, ctx_to.dev_id,
+                          from, ctx_from.dev_id,
+                          size, cu_stream);
+    }
+  } else if (ctx_from.dev_mask == kGPU && ctx_to.dev_mask == kCPU) {
+    CUDA_CALL(cudaSetDevice(ctx_from.dev_id));
+    GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
+  } else if (ctx_from.dev_mask == kCPU && ctx_to.dev_mask == kGPU) {
+    CUDA_CALL(cudaSetDevice(ctx_to.dev_id));
+    GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
+  } else {
+    LOG(FATAL) << "expect copy from/to GPU or between GPU";
+  }
+}
+
+template<>
+inline void StreamSync<kGPU>(TVMContext ctx, TVMStreamHandle stream) {
+  CUDA_CALL(cudaSetDevice(ctx.dev_id));
+  CUDA_CALL(cudaStreamSynchronize(
+      static_cast<cudaStream_t>(stream)));
+}
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_CUDA_RUNTIME
+#endif  // TVM_RUNTIME_DEVICE_API_GPU_H_
diff --git a/src/runtime/error_handle.cc b/src/runtime/error_handle.cc
index dad261e8e607..2a6c0bf75a7c 100644
--- a/src/runtime/error_handle.cc
+++ b/src/runtime/error_handle.cc
@@ -5,7 +5,7 @@
  */
 #include
 #include
-#include "./runtime_common.h"
+#include "./runtime_base.h"
 
 struct TVMErrorEntry {
   std::string last_error;
diff --git a/src/runtime/runtime_common.h b/src/runtime/runtime_base.h
similarity index 83%
rename from src/runtime/runtime_common.h
rename to src/runtime/runtime_base.h
index 998fd39857c9..1a3233342bac 100644
--- a/src/runtime/runtime_common.h
+++ b/src/runtime/runtime_base.h
@@ -1,13 +1,13 @@
 /*!
  * Copyright (c) 2016 by Contributors
- * \file c_runtime_common.h
- * \brief Common fields of all C APIs
+ * \file runtime_base.h
+ * \brief Base of all C APIs
  */
-#ifndef TVM_RUNTIME_RUNTIME_COMMON_H_
-#define TVM_RUNTIME_RUNTIME_COMMON_H_
+#ifndef TVM_RUNTIME_RUNTIME_BASE_H_
+#define TVM_RUNTIME_RUNTIME_BASE_H_
 
 #include
-#include
+#include
 
 /*! \brief macro to guard beginning and end section of all functions */
 #define API_BEGIN() try {
@@ -33,4 +33,4 @@ inline int TVMAPIHandleException(const std::runtime_error &e) {
   return -1;
 }
 
-#endif  // TVM_RUNTIME_RUNTIME_COMMON_H_
+#endif  // TVM_RUNTIME_RUNTIME_BASE_H_
diff --git a/tests/python/test_runtime_ndarray.py b/tests/python/test_runtime_ndarray.py
new file mode 100644
index 000000000000..602343b3efef
--- /dev/null
+++ b/tests/python/test_runtime_ndarray.py
@@ -0,0 +1,29 @@
+import tvm
+import numpy as np
+
+def enabled_ctx_list():
+    ctx_list = [tvm.cpu(0), tvm.gpu(0), tvm.opencl(0)]
+    ctx_list = [ctx for ctx in ctx_list if ctx.enabled]
+    return ctx_list
+
+ENABLED_CTX_LIST = enabled_ctx_list()
+print("Testing using contexts:", ENABLED_CTX_LIST)
+
+
+def test_nd_create():
+    for ctx in ENABLED_CTX_LIST:
+        for dtype in ["float32", "int8", "uint16"]:
+            x = np.random.randint(0, 10, size=(3, 4))
+            x = np.array(x, dtype=dtype)
+            y = tvm.nd.array(x, ctx=ctx)
+            z = y.copyto(ctx)
+            assert y.dtype == x.dtype
+            assert y.shape == x.shape
+            assert isinstance(y, tvm.nd.NDArray)
+            np.testing.assert_equal(x, y.asnumpy())
+            np.testing.assert_equal(x, z.asnumpy())
+        # not strictly needed here, just to test usability
+        tvm.nd.sync(ctx)
+
+if __name__ == "__main__":
+    test_nd_create()
diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh
index d682e2b0fb94..02cd71308f16 100755
--- a/tests/travis/run_test.sh
+++ b/tests/travis/run_test.sh
@@ -14,6 +14,9 @@ if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
     fi
 fi
 
+cp make/config.mk config.mk
+echo "USE_CUDA=0" >> config.mk
+echo "USE_OPENCL=0" >> config.mk
 
 if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
     # use g++-4.8 for linux