diff --git a/.travis.yml b/.travis.yml index 31d6e49f3dd1..e7110ecbacac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ language: cpp os: - linux - # - osx + - osx env: # code analysis diff --git a/HalideIR b/HalideIR index 30bf0f043e63..642ae50ac749 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 30bf0f043e6388418958fd1f29259ee43c42b600 +Subproject commit 642ae50ac749c91c04483db04500163304d4334e diff --git a/Makefile b/Makefile index 0cee2e36ed15..97c0e1ed3d86 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ all: lib/libtvm.a lib/libtvm.so LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a -SRC = $(wildcard src/*.cc src/*/*.cc) +SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR) @@ -39,7 +39,7 @@ endif ifeq ($(USE_CUDA), 1) CFLAGS += -DTVM_CUDA_RUNTIME=1 - LDFLAGS += -lcuda -lcudart + LDFLAGS += -lcuda -lcudart -lnvrtc else CFLAGS += -DTVM_CUDA_RUNTIME=0 endif @@ -92,3 +92,4 @@ clean: -include build/*.d -include build/*/*.d +-include build/*/*/*.d diff --git a/include/tvm/base.h b/include/tvm/base.h index c9c95f3c9bc2..726494ca4670 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -12,25 +12,9 @@ #include #include #include -#include -#include namespace tvm { -/*! - *\brief whether to use CUDA runtime - */ -#ifndef TVM_CUDA_RUNTIME -#define TVM_CUDA_RUNTIME 1 -#endif - -/*! - *\brief whether to use opencl runtime - */ -#ifndef TVM_OPENCL_RUNTIME -#define TVM_OPENCL_RUNTIME 0 -#endif - using ::tvm::Node; using ::tvm::NodeRef; using ::tvm::AttrVisitor; diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index ddafe500482f..4d76a88a7265 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -9,13 +9,18 @@ #include #include "./base.h" #include "./expr.h" -#include "./module.h" +#include "./lowered_func.h" #include "./runtime/packed_func.h" namespace tvm { /*! 
\brief namespace for lowlevel IR pass and codegen */ namespace codegen { +// use packed function from runtime. +using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; + /*! * \brief Make an user callable API LoweredFunc. * @@ -64,8 +69,35 @@ Array UndefinedVars(const LoweredFunc& f); */ Array SplitHostDevice(LoweredFunc func); +/*! + * \brief Build a stack VM function. + * \param func The LoweredFunc to be build + * \param device_funcs The additional device functions + * \return A packed function representing the func. + */ +PackedFunc BuildStackVM( + LoweredFunc func, + const std::unordered_map& device_funcs); + +/*! + * \brief Build a CUDA function with NVRTC + * + * \param fsplits The LoweredFuncs to be build (after SplitHostDevice) + * The first element is the host function, followed by device functions. + * \param host_mode The host side compilation mode: + * - "stackvm": use stack vm to interpret host side code. + */ +PackedFunc BuildNVRTC(Array fsplits, std::string host_mode); -runtime::PackedFunc BuildStackVM(LoweredFunc func); +/*! + * \brief Build a OpenCL function. + * + * \param fsplits The LoweredFuncs to be build (after SplitHostDevice) + * The first element is the host function, followed by device functions. + * \param host_mode The host side compilation mode: + * - "stackvm": use stack vm to interpret host side code. 
+ */ +PackedFunc BuildOpenCL(Array fsplits, std::string host_mode); } // namespace codegen } // namespace tvm diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 67b5dcdd4571..067c8dff3b14 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -12,7 +12,7 @@ #include #include #include "./base.h" -#include "./runtime/packed_func.h" +#include "./runtime/c_runtime_api.h" namespace tvm { @@ -33,6 +33,19 @@ using Halide::Internal::Variable; using Halide::Internal::make_const; + +inline Type TVMType2Type(TVMType t) { + return Type(static_cast(t.code), t.bits, t.lanes); +} + +inline TVMType Type2TVMType(Type t) { + TVMType ret; + ret.code = static_cast(t.code()); + ret.bits = static_cast(t.bits()); + ret.lanes = static_cast(t.lanes()); + return ret; +} + /*! \brief a named variable in TVM */ class Var : public Halide::VarExpr { public: diff --git a/include/tvm/module.h b/include/tvm/lowered_func.h similarity index 87% rename from include/tvm/module.h rename to include/tvm/lowered_func.h index 34ce7438b1a5..d9c685d77957 100644 --- a/include/tvm/module.h +++ b/include/tvm/lowered_func.h @@ -1,11 +1,11 @@ /*! - * Copyright (c) 2016 by Contributors - * \file module.h - * \brief Low level IR module, - * Contains lowered function information. + * Copyright (c) 2017 by Contributors + * \file lowered_func.h + * \brief Information about a lowered TVM function. + * This data structure is final step toward codegen. 
*/ -#ifndef TVM_MODULE_H_ -#define TVM_MODULE_H_ +#ifndef TVM_LOWERED_FUNC_H_ +#define TVM_LOWERED_FUNC_H_ #include #include @@ -102,4 +102,13 @@ inline const LoweredFuncNode* LoweredFunc::operator->() const { } // namespace tvm -#endif // TVM_MODULE_H_ +namespace std { +template <> +struct hash<::tvm::LoweredFunc> { + std::size_t operator()(const ::tvm::LoweredFunc& k) const { + return k.hash(); + } +}; +} + +#endif // TVM_LOWERED_FUNC_H_ diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index edaf43a806ec..94cfff26f1d0 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -14,6 +14,7 @@ #include "./base.h" #include "./expr.h" +#include "./runtime/packed_func.h" namespace tvm { using runtime::TVMArgs; @@ -162,19 +163,7 @@ inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLI type_codes_[i] = kNodeHandle; } -// Type related stuffs -inline Type TVMType2Type(TVMType t) { - return Type(static_cast(t.code), t.bits, t.lanes); -} - -inline TVMType Type2TVMType(Type t) { - TVMType ret; - ret.code = static_cast(t.code()); - ret.bits = static_cast(t.bits()); - ret.lanes = static_cast(t.lanes()); - return ret; -} - +// type related stuffs inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) { return this->operator=(Type2TVMType(t)); } diff --git a/include/tvm/runtime/config.h b/include/tvm/runtime/config.h new file mode 100644 index 000000000000..92a737f825bb --- /dev/null +++ b/include/tvm/runtime/config.h @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file config.h + * \brief Runtime library related configurations. + */ +#ifndef TVM_RUNTIME_CONFIG_H_ +#define TVM_RUNTIME_CONFIG_H_ + +/*! + *\brief whether to use CUDA runtime + */ +#ifndef TVM_CUDA_RUNTIME +#define TVM_CUDA_RUNTIME 1 +#endif + +/*! 
+ *\brief whether to use opencl runtime + */ +#ifndef TVM_OPENCL_RUNTIME +#define TVM_OPENCL_RUNTIME 0 +#endif + +#endif // TVM_RUNTIME_CONFIG_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index e3a391c84b4e..eafc367fe3c5 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -163,6 +163,13 @@ inline const char* TypeCode2Str(int type_code); */ inline TVMType String2TVMType(std::string s); +/*! + * \brief convert a TVM type to string. + * \param t The type to be converted. + * \return The corresponding tvm type in string. + */ +inline std::string TVMType2String(TVMType t); + // macro to check type code. #define TVM_CHECK_TYPE_CODE(CODE, T) \ CHECK_EQ(CODE, T) << " expected " \ @@ -258,6 +265,9 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator TVMArray*; // conversion operator. operator std::string() const { + if (type_code_ == kTVMType) { + return TVMType2String(operator TVMType()); + } TVM_CHECK_TYPE_CODE(type_code_, kStr); return std::string(value_.v_str); } @@ -308,7 +318,6 @@ class TVMRetValue : public TVMPODValue_ { */ TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { - other.type_code_ = kNull; } /*! \brief destructor */ ~TVMRetValue() { @@ -328,6 +337,9 @@ class TVMRetValue : public TVMPODValue_ { } // conversion operators operator std::string() const { + if (type_code_ == kTVMType) { + return TVMType2String(operator TVMType()); + } TVM_CHECK_TYPE_CODE(type_code_, kStr); return *ptr(); } @@ -418,6 +430,13 @@ class TVMRetValue : public TVMPODValue_ { *ret_type_code = type_code_; type_code_ = kNull; } + /*! 
\return The value field, if the data is POD */ + const TVMValue& value() const { + CHECK(type_code_ != kNodeHandle && + type_code_ != kFuncHandle && + type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; + return value_; + } // NodeRef related extenstions: in tvm/packed_func_ext.h inline TVMRetValue& operator=(const NodeRef& other); inline TVMRetValue& operator=(const std::shared_ptr& other); @@ -488,7 +507,7 @@ inline const char* TypeCode2Str(int type_code) { case kInt: return "int"; case kFloat: return "float"; case kStr: return "str"; - case kHandle: return "Handle"; + case kHandle: return "handle"; case kNull: return "NULL"; case kNodeHandle: return "NodeHandle"; case kArrayHandle: return "ArrayHandle"; @@ -499,6 +518,21 @@ inline const char* TypeCode2Str(int type_code) { } } +inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) + os << TypeCode2Str(t.code) + << static_cast(t.bits); + if (t.lanes != 1) { + os << 'x' << static_cast(t.lanes); + } + return os; +} + +inline std::string TVMType2String(TVMType t) { + std::ostringstream os; + os << t; + return os.str(); +} + inline TVMType String2TVMType(std::string s) { TVMType t; t.bits = 32; t.lanes = 1; diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 2bd6359b65e9..a1e2d4ff483b 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -13,7 +13,7 @@ from . import schedule from . 
import ndarray as nd -from .ndarray import cpu, gpu, opencl, init_opencl +from .ndarray import cpu, gpu, opencl, init_opencl, cl from ._base import TVMError from .api import * diff --git a/python/tvm/_ctypes/_function.py b/python/tvm/_ctypes/_function.py index 3b133e552cd3..e8377fe4bdfe 100644 --- a/python/tvm/_ctypes/_function.py +++ b/python/tvm/_ctypes/_function.py @@ -90,8 +90,8 @@ def _make_tvm_args(args, temp_args): values[i].v_float64 = arg type_codes[i] = TypeCode.FLOAT elif isinstance(arg, TVMType): - values[i].v_type = arg - type_codes[i] = TypeCode.TVM_TYPE + values[i].v_str = c_str(str(arg)) + type_codes[i] = TypeCode.STR elif isinstance(arg, string_types): values[i].v_str = c_str(arg) type_codes[i] = TypeCode.STR diff --git a/python/tvm/_ctypes/_types.py b/python/tvm/_ctypes/_types.py index f6d0f5e71a8f..58bdf111e54c 100644 --- a/python/tvm/_ctypes/_types.py +++ b/python/tvm/_ctypes/_types.py @@ -86,8 +86,7 @@ class TVMValue(ctypes.Union): _fields_ = [("v_int64", ctypes.c_int64), ("v_float64", ctypes.c_double), ("v_handle", ctypes.c_void_p), - ("v_str", ctypes.c_char_p), - ("v_type", TVMType)] + ("v_str", ctypes.c_char_p)] TVMPackedCFunc = ctypes.CFUNCTYPE( @@ -117,7 +116,6 @@ def _return_handle(x): TypeCode.FLOAT: lambda x: x.v_float64, TypeCode.HANDLE: _return_handle, TypeCode.NULL: lambda x: None, - TypeCode.TVM_TYPE: lambda x: x.v_type, TypeCode.STR: lambda x: py_str(x.v_str) } @@ -127,6 +125,5 @@ def _return_handle(x): TypeCode.FLOAT: lambda x: x.v_float64, TypeCode.HANDLE: _return_handle, TypeCode.NULL: lambda x: None, - TypeCode.TVM_TYPE: lambda x: x.v_type, TypeCode.STR: lambda x: py_str(x.v_str) } diff --git a/python/tvm/api.py b/python/tvm/api.py index 344ea1073e29..7c7a5b33c4f9 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -16,9 +16,9 @@ from . import expr as _expr from . 
import collections as _collections -int32 = TVMType("int32") -float32 = TVMType("float32") -handle = TVMType("handle") +int32 = "int32" +float32 = "float32" +handle = "handle" def const(value, dtype=None): """construct a constant""" diff --git a/python/tvm/collections.py b/python/tvm/collections.py index b3bff80be478..c24b1d81b58f 100644 --- a/python/tvm/collections.py +++ b/python/tvm/collections.py @@ -9,6 +9,12 @@ class Array(NodeBase): """Array container of TVM""" def __getitem__(self, i): + if isinstance(i, slice): + start = i.start if i.start is not None else 0 + stop = i.stop if i.stop is not None else len(self) + step = i.step if i.step is not None else 1 + return [self[idx] for idx in range(start, stop, step)] + if i >= len(self): raise IndexError("array index out ot range") return _api_internal._ArrayGetItem(self, i) diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index 8dbc22fd7721..2f2492eeafbc 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -2,7 +2,7 @@ This is a simplified runtime API for quick testing and proptyping. """ -# pylint: disable=unused-import +# pylint: disable=unused-import, invalid-name from __future__ import absolute_import as _abs import numpy as _np @@ -12,6 +12,8 @@ from ._ctypes._ndarray import init_opencl from ._ctypes._function import Function +cl = opencl + class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. diff --git a/src/README.md b/src/README.md index 652de2778998..16dfc19d8f54 100644 --- a/src/README.md +++ b/src/README.md @@ -1,8 +1,7 @@ # Code organization -- c_api C API related functions +- api API function registration - lang The definition of DSL related data structure - schedule The operations on the schedule graph before converting to IR. - pass The optimization pass on the IR structure - runtime Minimum runtime related codes. -- jit JIT runtime related code. 
diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 5f2c958c6a5e..1cb5a2ad0088 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -8,15 +8,30 @@ #include #include #include "../codegen/codegen_c.h" +#include "../codegen/codegen_cuda.h" +#include "../codegen/codegen_opencl.h" namespace tvm { namespace codegen { TVM_REGISTER_API(_codegen_CompileToC) .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = CodeGenC().Compile(args[0], args[1]); + std::string mode = "c"; + if (args.size() > 2) { + mode = args[2].operator std::string(); + } + if (mode == "c") { + *ret = CodeGenC().Compile(args[0], args[1]); + } else if (mode == "cuda") { + *ret = CodeGenCUDA().Compile(args[0], args[1]); + } else if (mode == "opencl") { + *ret = CodeGenOpenCL().Compile(args[0], args[1]); + } else { + LOG(FATAL) << "cannot recognize mode"; + } }); + TVM_REGISTER_API(_codegen_MakeAPI) .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = MakeAPI( @@ -28,29 +43,20 @@ TVM_REGISTER_API(_codegen_SplitHostDevice) *ret = SplitHostDevice(args[0]); }); -// generate a dummy packed function for testing -void DummyHelloFunction(TVMArgs args, TVMRetValue* rv) { - LOG(INFO) << args.size() << " arguments"; - for (int i = 0; i < args.size(); ++i) { - switch (args.type_codes[i]) { - case kNull: LOG(INFO) << i << ":nullptr"; break; - case kFloat: LOG(INFO) << i << ": double=" << args.values[i].v_float64; break; - case kInt: LOG(INFO) << i << ": long=" << args.values[i].v_int64; break; - case kHandle: LOG(INFO) << i << ": handle=" << args.values[i].v_handle; break; - case kArrayHandle: LOG(INFO) << i << ": array_handle=" << args.values[i].v_handle; break; - default: LOG(FATAL) << "unhandled type " << runtime::TypeCode2Str(args.type_codes[i]); - } - } -} +TVM_REGISTER_API(_codegen_BuildStackVM) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = BuildStackVM(args[0], + std::unordered_map()); + }); -TVM_REGISTER_API(_codegen_DummyHelloFunction) 
+TVM_REGISTER_API(_codegen_BuildNVRTC) .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = runtime::PackedFunc(DummyHelloFunction); + *ret = BuildNVRTC(args[0], args[1]); }); -TVM_REGISTER_API(_codegen_BuildStackVM) +TVM_REGISTER_API(_codegen_BuildOpenCL) .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = BuildStackVM(args[0]); + *ret = BuildOpenCL(args[0], args[1]); }); } // namespace codegen diff --git a/src/api/c_api.cc b/src/api/c_api.cc index c4290c57160a..712ed527fd13 100644 --- a/src/api/c_api.cc +++ b/src/api/c_api.cc @@ -121,7 +121,8 @@ int TVMNodeGetAttr(NodeHandle handle, } else { (*tnode)->VisitAttrs(&getter); *ret_success = getter.found_node_ref || rv.type_code() != kNull; - if (rv.type_code() == kStr) { + if (rv.type_code() == kStr || + rv.type_code() == kTVMType) { TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get(); e->ret_str = rv.operator std::string(); *ret_type_code = kStr; diff --git a/src/base/common.h b/src/base/common.h deleted file mode 100644 index 4b1c799e6fa3..000000000000 --- a/src/base/common.h +++ /dev/null @@ -1,119 +0,0 @@ -/*! 
- * Copyright (c) 2016 by Contributors - * \file common.h - * \brief Common utilities - */ -#ifndef TVM_BASE_COMMON_H_ -#define TVM_BASE_COMMON_H_ - -#include -#include -#include - -namespace tvm { - -inline std::string Type2String(const Type& t) { - if (t.code() ==Type::Handle) return "handle"; - std::ostringstream os; - os << t; - return os.str(); -} - -inline Type String2Type(std::string s) { - std::istringstream is(s); - halide_type_code_t code = Type::Int; - if (s.substr(0, 3) == "int") { - code = Type::Int; s = s.substr(3); - } else if (s.substr(0, 4) == "uint") { - code = Type::UInt; s = s.substr(4); - } else if (s.substr(0, 5) == "float") { - code = Type::Float; s = s.substr(5); - } else if (s.substr(0, 5) == "float") { - code = Type::Float; s = s.substr(5); - } else if (s == "handle") { - return Handle(); - } else { - LOG(FATAL) << "unknown type " << s; - } - int bits = 32, lanes = 1; - if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) { - LOG(FATAL) << "unknown type " << s; - } - return Type(code, bits, lanes); -} - -inline const char* TVMTypeCode2Str(int type_code) { - switch (type_code) { - case kInt: return "int"; - case kFloat: return "float"; - case kStr: return "str"; - case kHandle: return "Handle"; - case kNull: return "NULL"; - case kNodeHandle: return "NodeHandle"; - default: LOG(FATAL) << "unknown type_code=" - << static_cast(type_code); return ""; - } -} -template -struct NodeTypeChecker { - static inline bool Check(Node* sptr) { - // This is the only place in the project where RTTI is used - // It can be turned off, but will make non strict checking. 
- // TODO(tqchen) possibly find alternative to turn of RTTI - using ContainerType = typename T::ContainerType; - return (dynamic_cast(sptr) != nullptr); - } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - using ContainerType = typename T::ContainerType; - os << ContainerType::_type_key; - } -}; - -template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return false; - if (!sptr->is_type()) return false; - ArrayNode* n = static_cast(sptr); - for (const auto& p : n->data) { - if (!NodeTypeChecker::Check(p.get())) return false; - } - return true; - } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "array<"; - NodeTypeChecker::PrintName(os); - os << ">"; - } -}; - -template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return false; - if (!sptr->is_type()) return false; - MapNode* n = static_cast(sptr); - for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.first.get())) return false; - if (!NodeTypeChecker::Check(kv.second.get())) return false; - } - return true; - } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map<"; - NodeTypeChecker::PrintName(os); - os << ','; - NodeTypeChecker::PrintName(os); - os << '>'; - } -}; - -template -inline std::string NodeTypeName() { - std::ostringstream os; - NodeTypeChecker::PrintName(os); - return os.str(); -} - -} // namespace tvm -#endif // TVM_BASE_COMMON_H_ diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 327778db05bb..eade8577f2ea 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -25,7 +25,15 @@ std::string CodeGenC::Compile(LoweredFunc f, Var v = f->args[i]; std::string vid = AllocVarID(v.get()); if (i != 0) stream << ", "; - PrintType(v.type(), stream); + if (v.type().is_handle()) { + stream << arg_addr_space_; + } + if (handle_data_type_.count(v.get())) { + 
PrintType(handle_data_type_.at(v.get()), stream); + stream << "*"; + } else { + PrintType(v.type(), stream); + } stream << ' ' << vid; } stream << ") {\n"; @@ -510,6 +518,10 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) .set_dispatch([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); }) .set_dispatch([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); }); +void CodeGenC::PrintThreadTagExpr( + std::string thread_tag, std::ostream& os) const { // NOLINT(*) + os << thread_tag; +} void CodeGenC::PrintStmt(const LetStmt* op) { std::string value = PrintExpr(op->value); @@ -585,7 +597,9 @@ void CodeGenC::PrintStmt(const AttrStmt* op) { PrintType(iv->var.type(), stream); stream << ' ' << AllocVarID(iv->var.get()) - << " = " << iv->thread_tag << ";\n"; + << " = "; + PrintThreadTagExpr(iv->thread_tag, stream); + stream << ";\n"; } } } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index d9392d232657..30ae1d6c46bb 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -7,7 +7,8 @@ #define TVM_CODEGEN_CODEGEN_C_H_ #include -#include +#include +#include #include #include @@ -70,13 +71,20 @@ class CodeGenC { * \return the variable name. */ std::string GetVarID(const Variable* v) const; + // The following parts are overloadable print operations. /*! * Print Type represetnation of type t. * \param t The type representation. - * \return os The stream to print the ctype into + * \param os The stream to print the ctype into */ virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*) - // The following parts are overloadable print operations. + /*! + * \brief Print expr representing the thread tag + * \param thread_tag The tag in the thread. 
+ * \param os The stream to output to + */ + virtual void PrintThreadTagExpr( + std::string thread_tag, std::ostream& os) const; // NOLINT(*) virtual void PrintStmt(const ir::LetStmt* op); virtual void PrintStmt(const ir::Store* op); virtual void PrintStmt(const ir::Allocate* op); @@ -101,6 +109,10 @@ class CodeGenC { /*! \brief the stream to be printed */ std::ostringstream stream; + protected: + // additional string for arg addr_space. + std::string arg_addr_space_; + private: /*! * \brief Get the SSA ID corresponds to src diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc new file mode 100644 index 000000000000..a9b69ed9e491 --- /dev/null +++ b/src/codegen/codegen_cuda.cc @@ -0,0 +1,80 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_cuda.cc + */ +#include +#include +#include +#include +#include +#include "./codegen_cuda.h" +#include "./codegen_stack_vm.h" +#include "../runtime/cuda/cuda_common.h" +#include "../runtime/cuda/cuda_module.h" + +namespace tvm { +namespace codegen { + +std::string CodeGenCUDA::Compile( + LoweredFunc f, + bool output_ssa) { + this->stream << "extern \"C\" __global__ "; + return CodeGenC::Compile(f, output_ssa); +} + +#if TVM_CUDA_RUNTIME +std::unordered_map +MakeNVRTC(Array funcs) { + std::ostringstream os; + os << "typedef int int32_t;\n" + << "typedef unsigned uint32_t;\n"; + bool output_ssa = true; + for (LoweredFunc f : funcs) { + os << CodeGenCUDA().Compile(f, output_ssa); + os << '\n'; + } + std::string ptx = runtime::NVRTCCompile(os.str()); + std::unordered_map ret; + + runtime::CUDAModule m = runtime::CUDAModule::Create(ptx); + for (LoweredFunc f : funcs) { + std::vector arg_types(f->args.size()); + std::vector thread_axis_tags(f->thread_axis.size()); + + for (size_t i = 0; i < f->args.size(); ++i) { + arg_types[i] = Type2TVMType(f->args[i].type()); + } + for (size_t i = 0; i < f->thread_axis.size(); ++i) { + thread_axis_tags[i] = f->thread_axis[i]->thread_tag; + } + ret[f] = 
m.GetPackedFunc(f->name, arg_types, thread_axis_tags); + } + + return ret; +} + +PackedFunc BuildNVRTC(Array fsplits, std::string host_mode) { + Array device_list(fsplits.begin() + 1, fsplits.end()); + std::unordered_map device_funcs = MakeNVRTC(device_list); + + if (host_mode == "stackvm") { + StackVM vm = codegen::CodeGenStackVM().Compile(fsplits[0], device_funcs); + auto f = [vm](TVMArgs args, TVMRetValue* rv) { + runtime::AutoSetCUDADevice(args); + vm(args); + }; + return PackedFunc(f); + } else { + LOG(FATAL) << "unknown host mode " << host_mode; + return PackedFunc(); + } +} +#else +// dummy function when cuda is not available +PackedFunc BuildNVRTC(Array func, std::string host_mode) { + LOG(FATAL) << "CUDA is not enabled"; + return PackedFunc(); +} +#endif // TVM_CUDA_RUNTIME +} // namespace codegen +} // namespace tvm diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h new file mode 100644 index 000000000000..b8e55b80f4b2 --- /dev/null +++ b/src/codegen/codegen_cuda.h @@ -0,0 +1,33 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_cuda.h + * \brief Utility to generate cuda code + */ +#ifndef TVM_CODEGEN_CODEGEN_CUDA_H_ +#define TVM_CODEGEN_CODEGEN_CUDA_H_ + +#include +#include +#include +#include "./codegen_c.h" + +namespace tvm { +namespace codegen { + +class CodeGenCUDA : public CodeGenC { + public: + /*! + * \brief Generate the C code of statement + * \param f The function to be compiled + * \param output_ssa Whether output ssa form. + * \note Only call compile once, + * create a new codegen object each time. + */ + std::string Compile(LoweredFunc f, + bool output_ssa); +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_CODEGEN_CODEGEN_CUDA_H_ diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc new file mode 100644 index 000000000000..54b9b849461a --- /dev/null +++ b/src/codegen/codegen_opencl.cc @@ -0,0 +1,97 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file codegen_opencl.cc + */ +#include +#include +#include +#include +#include "./codegen_opencl.h" +#include "./codegen_stack_vm.h" +#include "../runtime/opencl/opencl_common.h" +#include "../runtime/opencl/opencl_module.h" + +namespace tvm { +namespace codegen { + +std::string CodeGenOpenCL::Compile( + LoweredFunc f, + bool output_ssa) { + this->stream << " __kernel "; + this->arg_addr_space_ = "__global "; + return CodeGenC::Compile(f, output_ssa); +} + +void CodeGenOpenCL::PrintThreadTagExpr( + std::string thread_tag, std::ostream& os) const { // NOLINT(*) + if (thread_tag == "threadIdx.x") { + os << "get_local_id(0)"; + } else if (thread_tag == "threadIdx.y") { + os << "get_local_id(1)"; + } else if (thread_tag == "threadIdx.z") { + os << "get_local_id(2)"; + } else if (thread_tag == "blockIdx.x") { + os << "get_global_id(0) / get_local_size(0)"; + } else if (thread_tag == "blockIdx.y") { + os << "get_global_id(1) / get_local_size(1)"; + } else if (thread_tag == "blockIdx.z") { + os << "get_global_id(2) / get_local_size(2)"; + } else { + LOG(FATAL) << "unknown thread tag"; + } +} + +#if TVM_OPENCL_RUNTIME +std::unordered_map +MakeOpenCL(Array funcs) { + std::ostringstream os; + os << "typedef int int32_t;\n" + << "typedef unsigned uint32_t;\n"; + bool output_ssa = true; + for (LoweredFunc f : funcs) { + os << CodeGenOpenCL().Compile(f, output_ssa); + os << '\n'; + } + std::unordered_map ret; + runtime::OpenCLModule m = + runtime::OpenCLModule::CreateWithSource(os.str()); + for (LoweredFunc f : funcs) { + std::vector arg_types(f->args.size()); + std::vector thread_axis_tags(f->thread_axis.size()); + + for (size_t i = 0; i < f->args.size(); ++i) { + arg_types[i] = Type2TVMType(f->args[i].type()); + } + for (size_t i = 0; i < f->thread_axis.size(); ++i) { + thread_axis_tags[i] = f->thread_axis[i]->thread_tag; + } + ret[f] = m.GetPackedFunc(f->name, arg_types, thread_axis_tags); + } + return ret; +} + +PackedFunc 
BuildOpenCL(Array fsplits, std::string host_mode) { + Array device_list(fsplits.begin() + 1, fsplits.end()); + std::unordered_map device_funcs = MakeOpenCL(device_list); + + if (host_mode == "stackvm") { + StackVM vm = codegen::CodeGenStackVM().Compile(fsplits[0], device_funcs); + auto f = [vm](TVMArgs args, TVMRetValue* rv) { + runtime::AutoSetOpenCLContext(args); + vm(args); + }; + return PackedFunc(f); + } else { + LOG(FATAL) << "unknown host mode " << host_mode; + return PackedFunc(); + } +} +#else +// dummy function when opencl is not available +PackedFunc BuildOpenCL(Array func, std::string host_mode) { + LOG(FATAL) << "OpenCL is not enabled"; + return PackedFunc(); +} +#endif // TVM_OPENCL_RUNTIME +} // namespace codegen +} // namespace tvm diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h new file mode 100644 index 000000000000..748599708752 --- /dev/null +++ b/src/codegen/codegen_opencl.h @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_opencl.h + * \brief Utility to generate opencl code + */ +#ifndef TVM_CODEGEN_CODEGEN_OPENCL_H_ +#define TVM_CODEGEN_CODEGEN_OPENCL_H_ + +#include +#include +#include +#include "./codegen_c.h" + +namespace tvm { +namespace codegen { + +class CodeGenOpenCL : public CodeGenC { + public: + /*! + * \brief Generate the OpenCL code of statement + * \param f The function to be compiled + * \param output_ssa Whether output ssa form. + * \note Only call compile once, + * create a new codegen object each time. + */ + std::string Compile(LoweredFunc f, + bool output_ssa); + // override print thread tag. 
+ void PrintThreadTagExpr( + std::string thread_tag, std::ostream& os) const final; // NOLINT(*) +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_CODEGEN_CODEGEN_OPENCL_H_ diff --git a/src/codegen/codegen_stack_vm.cc b/src/codegen/codegen_stack_vm.cc index 7748504f0c9c..a2fdf6235348 100644 --- a/src/codegen/codegen_stack_vm.cc +++ b/src/codegen/codegen_stack_vm.cc @@ -2,6 +2,7 @@ * Copyright (c) 2017 by Contributors * \file codegen_stack_vm.cc */ +#include #include #include "./codegen_stack_vm.h" @@ -10,55 +11,34 @@ namespace codegen { using namespace ir; -runtime::PackedFunc BuildStackVM(LoweredFunc func) { - StackVM vm = codegen::CodeGenStackVM().Compile(func); - using runtime::TVMArgs; - using runtime::TVMRetValue; - +PackedFunc BuildStackVM( + LoweredFunc func, + const std::unordered_map& device_funcs) { + StackVM vm = codegen::CodeGenStackVM().Compile(func, device_funcs); auto f = [vm](TVMArgs args, TVMRetValue* rv) { - StackVM::State* s = StackVM::ThreadLocalState(); - s->sp = 0; - s->pc = 0; - if (s->heap.size() < vm.heap_size) { - s->heap.resize(vm.heap_size); - } - s->heap[0].v_handle = (void*)args.values; // NOLINT(*) - s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*) - s->heap[2].v_int64 = args.num_args; - vm.Run(s); + vm(args); }; - - return runtime::PackedFunc(f); -} - -TVMValue TVMPrint(const TVMValue* args, int num_args) { - CHECK_EQ(num_args, 2); - int tcode = static_cast(args[1].v_int64); - int code = (tcode >> (8 * 3)) & 255; - int bits = (tcode >> (8 * 2)) & 255; - int lanes = tcode & ((1 << 16) - 1); - Type t((halide_type_code_t)code, bits, lanes); - if (t.is_handle()) { - LOG(INFO) << t << ": " << args[0].v_handle; - } else if (t.is_float()) { - LOG(INFO) << t << ": " << args[0].v_float64; - } else { - LOG(INFO) << t << ": " << args[0].v_int64; - } - TVMValue r; r.v_int64 = 0; - return r; + return PackedFunc(f); } CodeGenStackVM::FType& CodeGenStackVM::vtable() { // NOLINT(*) static FType inst; return inst; } 
-StackVM CodeGenStackVM::Compile(LoweredFunc f) { +StackVM CodeGenStackVM::Compile( + LoweredFunc f, + const std::unordered_map& device_funcs) { for (size_t i = 0; i < f->args.size(); ++i) { Var v = f->args[i]; int vid = AllocVarID(v.get()); CHECK_EQ(static_cast(vid), i); } + // setup device function map + for (const auto& kv : device_funcs) { + int fid = static_cast(vm_.packed_func.size()); + vm_.packed_func.push_back(kv.second); + device_fun_idmap_[kv.first] = fid; + } this->Push(f->body); return std::move(vm_); } @@ -117,33 +97,23 @@ int CodeGenStackVM::AllocVarID(const Variable* v) { return vid; } -int CodeGenStackVM::GetGlobalFuncID(std::string name) { - auto it = fun_idmap_.find(name); - if (it != fun_idmap_.end()) return it->second; - using runtime::PackedFunc; - using runtime::TVMArgs; - using runtime::TVMRetValue; - - PackedFunc f = PackedFunc::GetGlobal(name); - auto extern_f = [f](const TVMValue* args, int num_args) { - CHECK_EQ(num_args % 2, 0); - num_args = num_args / 2; - std::vector type_codes(std::max(num_args, 1)); - for (int i = 0; i < num_args; ++i) { - int tcode = static_cast(args[num_args + i].v_int64); - int code = (tcode >> (8 * 3)) & 255; - type_codes[i] = code; - } - TVMRetValue rv; - f.CallPacked(TVMArgs(args, &type_codes[0], num_args), &rv); - TVMValue r; r.v_int64 = 0; - return r; - }; - int fid = static_cast(vm_.extern_func.size()); - vm_.extern_func.push_back(extern_f); - fun_idmap_[name] = fid; - - return fid; +void CodeGenStackVM::PushCallPacked( + int fid, const std::vector& arg_type_codes) { + StackVM::Code code; + // CALL_PACKED_FUNC + code.op_code = StackVM::CALL_PACKED_FUNC; + vm_.code.push_back(code); + // num_args + code.v_int = static_cast(arg_type_codes.size()); + vm_.code.push_back(code); + // fid + code.v_int = fid; + vm_.code.push_back(code); + // type codes. 
+ for (int tcode : arg_type_codes) { + code.v_int = tcode; + vm_.code.push_back(code); + } } int CodeGenStackVM::GetVarID(const Variable* v) const { @@ -162,7 +132,7 @@ void CodeGenStackVM::Push_(const ir::Load* op) { this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); - this->PushOp(StackVM::GetLoad(op->type)); + this->PushOp(StackVM::GetLoad(Type2TVMType(op->type))); } } void CodeGenStackVM::Push_(const ir::Store* op) { @@ -172,7 +142,7 @@ void CodeGenStackVM::Push_(const ir::Store* op) { this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); this->Push(op->value); - this->PushOp(StackVM::GetStore(op->value.type())); + this->PushOp(StackVM::GetStore(Type2TVMType(op->value.type()))); } void CodeGenStackVM::Push_(const ir::Allocate* op) { @@ -231,22 +201,49 @@ void CodeGenStackVM::Push_(const ir::Call* op) { for (size_t i = 1; i < op->args.size(); ++i) { this->Push(op->args[i]); } + // find the fuction id. + const std::string& func_name = s->value; + auto it = global_fun_idmap_.find(func_name); + int fid; + if (it != global_fun_idmap_.end()) { + fid = it->second; + } else { + fid = static_cast(vm_.packed_func.size()); + PackedFunc f = PackedFunc::GetGlobal(func_name); + vm_.packed_func.push_back(f); + global_fun_idmap_[func_name] = fid; + } + // get the argument type code. 
+ std::vector arg_type_codes; for (size_t i = 1; i < op->args.size(); ++i) { Type t = op->args[i].type(); int code = t.code(); - int bits = t.bits(); int lanes = t.lanes(); - int tcode = (code << (8 * 3)) | (bits << 16) | lanes; - this->PushOp(StackVM::PUSH_I64, tcode); + CHECK_EQ(lanes, 1); + arg_type_codes.push_back(code); } - int num_args = static_cast((op->args.size() - 1) * 2); - this->PushOp(StackVM::PUSH_I64, num_args); - this->PushOp(StackVM::CALL_EXTERN, GetGlobalFuncID(s->value)); + this->PushCallPacked(fid, arg_type_codes); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { CHECK_EQ(op->args.size(), 1U); this->Push(op->args[0]); this->PushOp(StackVM::PUSH_I64, 0); this->PushOp(StackVM::EQ_I64); + } else if (op->call_type == Call::Extern && op->func.defined()) { + CHECK(op->func->is_type()); + LoweredFunc f(op->func.node_); + auto it = device_fun_idmap_.find(f); + CHECK(it != device_fun_idmap_.end()) + << "Cannot find device function " << f->name; + const int fid = it->second; + std::vector arg_type_codes(op->args.size()); + for (size_t i = 0; i < op->args.size(); ++i) { + this->Push(op->args[i]); + Type t = op->args[i].type(); + int lanes = t.lanes(); + CHECK_EQ(lanes, 1); + arg_type_codes[i] = t.code(); + } + this->PushCallPacked(fid, arg_type_codes); } else { this->HandleUnknownCall(op); } @@ -277,6 +274,8 @@ inline void PushBinary(StackVM::OpCode op_int64, } + + inline void PushCast(Type dst, Type src, CodeGenStackVM* p) { diff --git a/src/codegen/codegen_stack_vm.h b/src/codegen/codegen_stack_vm.h index b8ecf79612f7..6a81b3bd6b7f 100644 --- a/src/codegen/codegen_stack_vm.h +++ b/src/codegen/codegen_stack_vm.h @@ -7,17 +7,18 @@ #define TVM_CODEGEN_CODEGEN_STACK_VM_H_ #include -#include +#include #include #include +#include #include -#include "../jit/stack_vm.h" +#include "../runtime/stack_vm/stack_vm.h" namespace tvm { namespace codegen { -using jit::StackVM; +using runtime::StackVM; /*! * \brief A base class to generate a stack VM. 
@@ -26,13 +27,16 @@ using jit::StackVM; */ class CodeGenStackVM { public: - /*! + /*! * \brief Generate a stack VM representing * \param f The function to be compiled + * \param device_funcs The extern device functions to be linked. * \note Only call compile once, * create a new codegen object each time. */ - StackVM Compile(LoweredFunc f); + StackVM Compile( + LoweredFunc f, + const std::unordered_map& device_funcs); /*! \brief Push stmt to generate new code */ void Push(const Stmt& n); /*! \brief Push expr to generate new code */ @@ -49,6 +53,13 @@ class CodeGenStackVM { * \return operand_index, indicating location of operand */ int64_t PushOp(StackVM::OpCode opcode, int operand); + /*! + * \brief Push a call packed function. + * \param fid The function id. + * \param arg_type_codes The type codes of arguments. + */ + void PushCallPacked(int fid, + const std::vector& arg_type_codes); /*! * \brief Set the relative jump offset to be offset. * \param operand_index The indexed returned by PushOp. @@ -65,11 +76,6 @@ class CodeGenStackVM { * \return the id of the string. */ int GetStrID(const std::string& key); - /*! - * \brief Push the function to the VM and get a id. - * \param f The function to be pushed. - */ - int GetGlobalFuncID(std::string name); /*! * \brief Allocate a variable name for a newly defined var. * \param v The variable. @@ -101,8 +107,10 @@ class CodeGenStackVM { std::unordered_map var_idmap_; /*! \brief id of each string */ std::unordered_map str_idmap_; - /*! \brief id of each function */ - std::unordered_map fun_idmap_; + /*! \brief id of each global function */ + std::unordered_map global_fun_idmap_; + /*! 
\brief id of device function */ + std::unordered_map device_fun_idmap_; }; } // namespace codegen diff --git a/src/codegen/split_host_device.cc b/src/codegen/split_host_device.cc index 383e307a1501..213dd8a40dfc 100644 --- a/src/codegen/split_host_device.cc +++ b/src/codegen/split_host_device.cc @@ -5,7 +5,7 @@ */ #include #include -#include +#include #include #include #include @@ -169,6 +169,7 @@ class HostDeviceSplitter : public IRMutator { n->body = m.Mutate(body); n->name = os.str(); n->args = m.undefined_; + n->thread_axis = m.thread_axis_; CHECK_NE(m.thread_extent_.size(), 0U); // improve the handle data type diff --git a/src/base/saveload_json.cc b/src/lang/saveload_json.cc similarity index 90% rename from src/base/saveload_json.cc rename to src/lang/saveload_json.cc index 6a877caf4678..183866e1f2bd 100644 --- a/src/base/saveload_json.cc +++ b/src/lang/saveload_json.cc @@ -4,13 +4,44 @@ * \brief Utilities to save/load TVM objects. */ #include +#include #include #include #include -#include "./common.h" namespace tvm { +inline std::string Type2String(const Type& t) { + if (t.code() ==Type::Handle) return "handle"; + std::ostringstream os; + os << t; + return os.str(); +} + +inline Type String2Type(std::string s) { + std::istringstream is(s); + halide_type_code_t code = Type::Int; + if (s.substr(0, 3) == "int") { + code = Type::Int; s = s.substr(3); + } else if (s.substr(0, 4) == "uint") { + code = Type::UInt; s = s.substr(4); + } else if (s.substr(0, 5) == "float") { + code = Type::Float; s = s.substr(5); + } else if (s.substr(0, 5) == "float") { + code = Type::Float; s = s.substr(5); + } else if (s == "handle") { + return Handle(); + } else { + LOG(FATAL) << "unknown type " << s; + } + int bits = 32, lanes = 1; + if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) { + LOG(FATAL) << "unknown type " << s; + } + return Type(code, bits, lanes); +} + + // indexer to index all the ndoes class NodeIndexer : public AttrVisitor { public: diff --git 
a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 17fb97de3a0f..aea44946a4fc 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -193,7 +193,8 @@ int TVMFuncCall(TVMFunctionHandle func, (*static_cast(func)).CallPacked( TVMArgs(args, arg_type_codes, num_args), &rv); // handle return string. - if (rv.type_code() == kStr) { + if (rv.type_code() == kStr || + rv.type_code() == kTVMType) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); e->ret_str = rv.operator std::string(); *ret_type_code = kStr; diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h new file mode 100644 index 000000000000..b735064e01c4 --- /dev/null +++ b/src/runtime/cuda/cuda_common.h @@ -0,0 +1,70 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file cuda_common.h + * \brief Common utilities for CUDA + */ +#ifndef TVM_RUNTIME_CUDA_CUDA_COMMON_H_ +#define TVM_RUNTIME_CUDA_CUDA_COMMON_H_ + +#include +#include +#include + +#if TVM_CUDA_RUNTIME +#include + +namespace tvm { +namespace runtime { + +#define CUDA_DRIVER_CALL(x) \ + { \ + CUresult result = x; \ + if (result != CUDA_SUCCESS) { \ + const char *msg; \ + cuGetErrorName(result, &msg); \ + LOG(FATAL) \ + << "CUDAError: " #x " failed with error: " << msg; \ + } \ + } + +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ + << "CUDA: " << cudaGetErrorString(e); \ + } + + +/*! + * \brief Compile code into ptx using NVRTC + * \param code The cuda code. + * \return The PTX code. + */ +std::string NVRTCCompile(const std::string& code); + +/*! + * \brief Automatically detect and set cuda device. + * \param args The arguments. 
+ */ +inline void AutoSetCUDADevice(const TVMArgs& args) { + int dev_id = -1; + for (int i = 0; i < args.size(); ++i) { + if (args.type_codes[i] == kArrayHandle) { + TVMContext ctx = static_cast( + args.values[i].v_handle)->ctx; + CHECK_EQ(ctx.dev_mask, kGPU) + << "All operands need to be GPU"; + if (dev_id == -1) { + dev_id = ctx.dev_id; + } else { + CHECK_EQ(dev_id, ctx.dev_id) + << "Operands comes from different devices "; + } + } + } + CUDA_CALL(cudaSetDevice(dev_id)); +} +} // namespace runtime +} // namespace tvm +#endif // TVM_CUDA_RUNTIME +#endif // TVM_RUNTIME_CUDA_CUDA_COMMON_H_ diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc new file mode 100644 index 000000000000..83ead0847598 --- /dev/null +++ b/src/runtime/cuda/cuda_module.cc @@ -0,0 +1,132 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file cuda_module.cc + */ +#include "./cuda_module.h" + +#if TVM_CUDA_RUNTIME +#include +#include +#include +#include +#include +#include +#include "./cuda_common.h" +#include "../void_addr_args.h" +#include "../thread_axis_args.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief Internal data structure to support multi-gpu execution. + * Try to use CUDA runtime's primary context. 
+ */ +class CUDAModule::Internal { + public: + explicit Internal(std::string data) + : data_(data) { + std::fill(module_.begin(), module_.end(), nullptr); + } + // get a CUfunction from primary context in dev_id + CUfunction GetFunc(int dev_id, const std::string& func_name) { + std::lock_guard lock(mutex_); + // must recheck under the lock scope + if (module_[dev_id] == nullptr) { + CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[dev_id]), data_.c_str())); + } + CUfunction func; + CUresult result = cuModuleGetFunction(&func, module_[dev_id], func_name.c_str()); + if (result != CUDA_SUCCESS) { + const char *msg; + cuGetErrorName(result, &msg); + LOG(FATAL) + << "CUDAError: cuModuleGetFunction " << func_name + << " failed with error: " << msg; + } + return func; + } + // destructor + ~Internal() { + for (size_t i = 0; i < module_.size(); ++i) { + if (module_[i] != nullptr) { + CUDA_CALL(cudaSetDevice(i)); + CUDA_DRIVER_CALL(cuModuleUnload(module_[i])); + } + } + } + + private: + // the binary data + std::string data_; + // the internal modules per GPU, to be lazily initialized. + std::array module_; + // internal mutex when updating the module + std::mutex mutex_; +}; + +// a wrapped function class to get packed fucn. +class CUDAWrappedFunc { + public: + // initialize the CUDA function. 
+ void Init(std::shared_ptr m, + const std::string& func_name, + size_t num_void_args, + const std::vector& thread_axis_tags) { + m_ = m; + func_name_ = func_name; + std::fill(fcache_.begin(), fcache_.end(), nullptr); + thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + } + // invoke the function with void arguments + void operator()(TVMArgs args, + TVMRetValue* rv, + void** void_args) const { + int dev_id; + CUDA_CALL(cudaGetDevice(&dev_id)); + if (fcache_[dev_id] == nullptr) { + fcache_[dev_id] = m_->GetFunc(dev_id, func_name_); + } + ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + CUDA_DRIVER_CALL(cuLaunchKernel( + fcache_[dev_id], + wl.grid_dim(0), + wl.grid_dim(1), + wl.grid_dim(2), + wl.block_dim(0), + wl.block_dim(1), + wl.block_dim(2), + 0, nullptr, void_args, 0)); + } + + private: + // internal module + std::shared_ptr m_; + // The name of the function. + std::string func_name_; + // Device function cache per device. + // mark as mutable, to enable lazy initialization + mutable std::array fcache_; + // thread axis configuration + ThreadAxisConfig thread_axis_cfg_; +}; + +PackedFunc CUDAModule::GetPackedFunc( + const std::string& func_name, + const std::vector arg_types, + const std::vector thread_axis_tags) const { + CUDAWrappedFunc f; + f.Init(ptr_, func_name, arg_types.size(), thread_axis_tags); + return PackFromVoidAddrArgs(f, arg_types); +} + +CUDAModule CUDAModule::Create(std::string ptx) { + // call a runtime API to make sure the context is created. + CUDAModule m; + m.ptr_ = std::make_shared(ptx); + return m; +} +} // namespace runtime +} // namespace tvm + +#endif // TVM_CUDA_RUNTIME diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h new file mode 100644 index 000000000000..38c805eea9d8 --- /dev/null +++ b/src/runtime/cuda/cuda_module.h @@ -0,0 +1,50 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file cuda_module.h + * \brief Execution handling of CUDA kernels + */ +#ifndef TVM_RUNTIME_CUDA_CUDA_MODULE_H_ +#define TVM_RUNTIME_CUDA_CUDA_MODULE_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Handle execution of CUDA kernels as PackedFunc. + * It wraps around driver API to work with CUDA runtime API. + */ +class CUDAModule { + public: + /*! + * \brief Get CUDA Kernel launch wrapped as PackedFunc + * \param func_name The name of the function. + * \param arg_types The type of each argument in the function. + * \param thread_axis_tags The tag sequence of the thread axis. + */ + PackedFunc GetPackedFunc( + const std::string& func_name, + const std::vector arg_types, + const std::vector thread_axis_tags) const; + /*! + * \brief create a cuda module from data. + * \param data The module data. + */ + static CUDAModule Create(std::string data); + /*! \brief hidden internal data structure. */ + class Internal; + /*! \brief Maximum number of GPU supported in CUDAModule */ + static constexpr const int kMaxNumGPUs = 32; + + private: + std::shared_ptr ptr_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_ diff --git a/src/runtime/device_api_gpu.h b/src/runtime/cuda/device_api_cuda.h similarity index 77% rename from src/runtime/device_api_gpu.h rename to src/runtime/cuda/device_api_cuda.h index b18a95dcb0a6..dc2954b026f5 100644 --- a/src/runtime/device_api_gpu.h +++ b/src/runtime/cuda/device_api_cuda.h @@ -1,33 +1,21 @@ /*! 
- * Copyright (c) 2016 by Contributors - * \file device_api_gpu.h + * Copyright (c) 2017 by Contributors + * \file device_api_cuda.h * \brief GPU specific API */ -#ifndef TVM_RUNTIME_DEVICE_API_GPU_H_ -#define TVM_RUNTIME_DEVICE_API_GPU_H_ +#ifndef TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_ +#define TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_ -#include -#include "./device_api.h" +#include "./cuda_common.h" #if TVM_CUDA_RUNTIME + +#include #include namespace tvm { namespace runtime { -/*! - * \brief Protected CUDA call. - * \param func Expression to call. - * - * It checks for CUDA errors after invocation of the expression. - */ -#define CUDA_CALL(func) \ - { \ - cudaError_t e = (func); \ - CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ - << "CUDA: " << cudaGetErrorString(e); \ - } - template<> inline void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) { CUDA_CALL(cudaSetDevice(ctx.dev_id)); @@ -94,4 +82,4 @@ inline void StreamSync(TVMContext ctx, TVMStreamHandle stream) { } // namespace runtime } // namespace tvm #endif // TVM_CUDA_RUNTIME -#endif // TVM_RUNTIME_DEVICE_API_GPU_H_ +#endif // TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_ diff --git a/src/runtime/cuda/nvrtc.cc b/src/runtime/cuda/nvrtc.cc new file mode 100644 index 000000000000..38c13e22e7fc --- /dev/null +++ b/src/runtime/cuda/nvrtc.cc @@ -0,0 +1,46 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file nvrtc.cc + */ +#include "./cuda_common.h" + +#if TVM_CUDA_RUNTIME + +#include + +namespace tvm { +namespace runtime { + +#define NVRTC_CALL(x) \ + { \ + nvrtcResult result = x; \ + if (result != NVRTC_SUCCESS) { \ + LOG(FATAL) \ + << "NvrtcError: " #x " failed with error: " \ + << nvrtcGetErrorString(result); \ + } \ + } + +std::string NVRTCCompile(const std::string& code) { + nvrtcProgram prog; + NVRTC_CALL(nvrtcCreateProgram( + &prog, code.c_str(), nullptr, 0, nullptr, nullptr)); + nvrtcResult compile_res = nvrtcCompileProgram(prog, 0, nullptr); + size_t log_size; + NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size)); + std::string log; log.resize(log_size); + NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0])); + CHECK_EQ(compile_res, NVRTC_SUCCESS) << log; + size_t ptx_size; + NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size)); + + std::string ptx; + ptx.resize(ptx_size); + NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0])); + NVRTC_CALL(nvrtcDestroyProgram(&prog)); + return ptx; +} + +} // namespace runtime +} // namespace tvm +#endif // TVM_CUDA_RUNTIME diff --git a/src/runtime/device_api.h b/src/runtime/device_api.h index 3ef1a7c0ead7..f551ca9ee8c8 100644 --- a/src/runtime/device_api.h +++ b/src/runtime/device_api.h @@ -109,7 +109,7 @@ inline void StreamSync(TVMContext ctx, TVMStreamHandle stream); } // namespace tvm #include "./device_api_cpu.h" -#include "./device_api_gpu.h" -#include "./device_api_opencl.h" +#include "./cuda/device_api_cuda.h" +#include "./opencl/device_api_opencl.h" #endif // TVM_RUNTIME_DEVICE_API_H_ diff --git a/src/runtime/device_api_opencl.h b/src/runtime/opencl/device_api_opencl.h similarity index 54% rename from src/runtime/device_api_opencl.h rename to src/runtime/opencl/device_api_opencl.h index 257262beb0d7..3d2c2d1b458d 100644 --- a/src/runtime/device_api_opencl.h +++ b/src/runtime/opencl/device_api_opencl.h @@ -1,138 +1,22 @@ /*! 
- * Copyright (c) 2016 by Contributors + * Copyright (c) 2017 by Contributors * \file device_api_opencl.h * \brief OpenCL specific API */ -#ifndef TVM_RUNTIME_DEVICE_API_OPENCL_H_ -#define TVM_RUNTIME_DEVICE_API_OPENCL_H_ +#ifndef TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_ +#define TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_ -#if TVM_OPENCL_RUNTIME - -#ifdef __APPLE__ -#include -#else -#include -#endif +#include -#include +#if TVM_OPENCL_RUNTIME #include #include - +#include "./opencl_common.h" namespace tvm { namespace runtime { namespace cl { -static_assert(sizeof(cl_mem) ==sizeof(void*), - "Required to store cl_mem inside void*"); - -inline const char* CLGetErrorString(cl_int error) { - switch (error) { - case CL_SUCCESS: return "CL_SUCCESS"; - case CL_DEVICE_NOT_FOUND: return "CL_DEVICE_NOT_FOUND"; - case CL_DEVICE_NOT_AVAILABLE: return "CL_DEVICE_NOT_AVAILABLE"; - case CL_COMPILER_NOT_AVAILABLE: return "CL_COMPILER_NOT_AVAILABLE"; - case CL_MEM_OBJECT_ALLOCATION_FAILURE: return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; - case CL_OUT_OF_RESOURCES: return "CL_OUT_OF_RESOURCES"; - case CL_OUT_OF_HOST_MEMORY: return "CL_OUT_OF_HOST_MEMORY"; - case CL_PROFILING_INFO_NOT_AVAILABLE: return "CL_PROFILING_INFO_NOT_AVAILABLE"; - case CL_MEM_COPY_OVERLAP: return "CL_MEM_COPY_OVERLAP"; - case CL_IMAGE_FORMAT_MISMATCH: return "CL_IMAGE_FORMAT_MISMATCH"; - case CL_IMAGE_FORMAT_NOT_SUPPORTED: return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; - case CL_BUILD_PROGRAM_FAILURE: return "CL_BUILD_PROGRAM_FAILURE"; - case CL_MAP_FAILURE: return "CL_MAP_FAILURE"; - case CL_INVALID_VALUE: return "CL_INVALID_VALUE"; - case CL_INVALID_DEVICE_TYPE: return "CL_INVALID_DEVICE_TYPE"; - case CL_INVALID_PLATFORM: return "CL_INVALID_PLATFORM"; - case CL_INVALID_DEVICE: return "CL_INVALID_DEVICE"; - case CL_INVALID_CONTEXT: return "CL_INVALID_CONTEXT"; - case CL_INVALID_QUEUE_PROPERTIES: return "CL_INVALID_QUEUE_PROPERTIES"; - case CL_INVALID_COMMAND_QUEUE: return "CL_INVALID_COMMAND_QUEUE"; - case 
CL_INVALID_HOST_PTR: return "CL_INVALID_HOST_PTR"; - case CL_INVALID_MEM_OBJECT: return "CL_INVALID_MEM_OBJECT"; - case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; - case CL_INVALID_IMAGE_SIZE: return "CL_INVALID_IMAGE_SIZE"; - case CL_INVALID_SAMPLER: return "CL_INVALID_SAMPLER"; - case CL_INVALID_BINARY: return "CL_INVALID_BINARY"; - case CL_INVALID_BUILD_OPTIONS: return "CL_INVALID_BUILD_OPTIONS"; - case CL_INVALID_PROGRAM: return "CL_INVALID_PROGRAM"; - case CL_INVALID_PROGRAM_EXECUTABLE: return "CL_INVALID_PROGRAM_EXECUTABLE"; - case CL_INVALID_KERNEL_NAME: return "CL_INVALID_KERNEL_NAME"; - case CL_INVALID_KERNEL_DEFINITION: return "CL_INVALID_KERNEL_DEFINITION"; - case CL_INVALID_KERNEL: return "CL_INVALID_KERNEL"; - case CL_INVALID_ARG_INDEX: return "CL_INVALID_ARG_INDEX"; - case CL_INVALID_ARG_VALUE: return "CL_INVALID_ARG_VALUE"; - case CL_INVALID_ARG_SIZE: return "CL_INVALID_ARG_SIZE"; - case CL_INVALID_KERNEL_ARGS: return "CL_INVALID_KERNEL_ARGS"; - case CL_INVALID_WORK_DIMENSION: return "CL_INVALID_WORK_DIMENSION"; - case CL_INVALID_WORK_GROUP_SIZE: return "CL_INVALID_WORK_GROUP_SIZE"; - case CL_INVALID_WORK_ITEM_SIZE: return "CL_INVALID_WORK_ITEM_SIZE"; - case CL_INVALID_GLOBAL_OFFSET: return "CL_INVALID_GLOBAL_OFFSET"; - case CL_INVALID_EVENT_WAIT_LIST: return "CL_INVALID_EVENT_WAIT_LIST"; - case CL_INVALID_EVENT: return "CL_INVALID_EVENT"; - case CL_INVALID_OPERATION: return "CL_INVALID_OPERATION"; - case CL_INVALID_GL_OBJECT: return "CL_INVALID_GL_OBJECT"; - case CL_INVALID_BUFFER_SIZE: return "CL_INVALID_BUFFER_SIZE"; - case CL_INVALID_MIP_LEVEL: return "CL_INVALID_MIP_LEVEL"; - default: return "Unknown OpenCL error code"; - } -} - -/*! - * \brief Protected OpenCL call - * \param func Expression to call. 
- */ -#define OPENCL_CHECK_ERROR(e) \ - { \ - CHECK(e == CL_SUCCESS) \ - << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); \ - } - -#define OPENCL_CALL(func) \ - { \ - cl_int e = (func); \ - OPENCL_CHECK_ERROR(e); \ - } - -// Process local opencl workspace -class OpenCLWorkspace { - public: - // global platform id - cl_platform_id platform_id; - // global context of this process - cl_context context{nullptr}; - // the devices - std::vector devices; - // the queues - std::vector queues; - // the mutex for initialization - std::mutex mu; - // destructor - ~OpenCLWorkspace() { - if (context != nullptr) { - OPENCL_CALL(clReleaseContext(context)); - } - } - // whether the workspace is initialized. - inline bool initialized() const { - return context != nullptr; - } - // get the queue of the context - cl_command_queue GetQueue(TVMContext ctx) const { - CHECK_EQ(ctx.dev_mask, kOpenCL); - CHECK(initialized()) - << "The OpenCL is not initialized"; - CHECK(ctx.dev_id >= 0 && static_cast(ctx.dev_id) < queues.size()) - << "Invalid OpenCL dev_id=" << ctx.dev_id; - return queues[ctx.dev_id]; - } - // get the global workspace - static OpenCLWorkspace* Global() { - static OpenCLWorkspace inst; - return &inst; - } -}; - inline std::string GetPlatformInfo( cl_platform_id pid, cl_platform_info param_name) { size_t ret_size; @@ -307,4 +191,4 @@ inline void StreamSync(TVMContext ctx, TVMStreamHandle stream) { } // namespace runtime } // namespace tvm #endif // TVM_OPENCL_RUNTIME -#endif // TVM_RUNTIME_DEVICE_API_OPENCL_H_ +#endif // TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_ diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h new file mode 100644 index 000000000000..e22bc3403389 --- /dev/null +++ b/src/runtime/opencl/opencl_common.h @@ -0,0 +1,184 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file opencl_common.h + * \brief OpenCL common header + */ +#ifndef TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ +#define TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ + +#include +#include +#include +#include + +#if TVM_OPENCL_RUNTIME + +#ifdef __APPLE__ +#include +#else +#include +#endif + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace cl { + +static_assert(sizeof(cl_mem) ==sizeof(void*), + "Required to store cl_mem inside void*"); + +inline const char* CLGetErrorString(cl_int error) { + switch (error) { + case CL_SUCCESS: return "CL_SUCCESS"; + case CL_DEVICE_NOT_FOUND: return "CL_DEVICE_NOT_FOUND"; + case CL_DEVICE_NOT_AVAILABLE: return "CL_DEVICE_NOT_AVAILABLE"; + case CL_COMPILER_NOT_AVAILABLE: return "CL_COMPILER_NOT_AVAILABLE"; + case CL_MEM_OBJECT_ALLOCATION_FAILURE: return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; + case CL_OUT_OF_RESOURCES: return "CL_OUT_OF_RESOURCES"; + case CL_OUT_OF_HOST_MEMORY: return "CL_OUT_OF_HOST_MEMORY"; + case CL_PROFILING_INFO_NOT_AVAILABLE: return "CL_PROFILING_INFO_NOT_AVAILABLE"; + case CL_MEM_COPY_OVERLAP: return "CL_MEM_COPY_OVERLAP"; + case CL_IMAGE_FORMAT_MISMATCH: return "CL_IMAGE_FORMAT_MISMATCH"; + case CL_IMAGE_FORMAT_NOT_SUPPORTED: return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; + case CL_BUILD_PROGRAM_FAILURE: return "CL_BUILD_PROGRAM_FAILURE"; + case CL_MAP_FAILURE: return "CL_MAP_FAILURE"; + case CL_INVALID_VALUE: return "CL_INVALID_VALUE"; + case CL_INVALID_DEVICE_TYPE: return "CL_INVALID_DEVICE_TYPE"; + case CL_INVALID_PLATFORM: return "CL_INVALID_PLATFORM"; + case CL_INVALID_DEVICE: return "CL_INVALID_DEVICE"; + case CL_INVALID_CONTEXT: return "CL_INVALID_CONTEXT"; + case CL_INVALID_QUEUE_PROPERTIES: return "CL_INVALID_QUEUE_PROPERTIES"; + case CL_INVALID_COMMAND_QUEUE: return "CL_INVALID_COMMAND_QUEUE"; + case CL_INVALID_HOST_PTR: return "CL_INVALID_HOST_PTR"; + case CL_INVALID_MEM_OBJECT: return "CL_INVALID_MEM_OBJECT"; + case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: 
return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; + case CL_INVALID_IMAGE_SIZE: return "CL_INVALID_IMAGE_SIZE"; + case CL_INVALID_SAMPLER: return "CL_INVALID_SAMPLER"; + case CL_INVALID_BINARY: return "CL_INVALID_BINARY"; + case CL_INVALID_BUILD_OPTIONS: return "CL_INVALID_BUILD_OPTIONS"; + case CL_INVALID_PROGRAM: return "CL_INVALID_PROGRAM"; + case CL_INVALID_PROGRAM_EXECUTABLE: return "CL_INVALID_PROGRAM_EXECUTABLE"; + case CL_INVALID_KERNEL_NAME: return "CL_INVALID_KERNEL_NAME"; + case CL_INVALID_KERNEL_DEFINITION: return "CL_INVALID_KERNEL_DEFINITION"; + case CL_INVALID_KERNEL: return "CL_INVALID_KERNEL"; + case CL_INVALID_ARG_INDEX: return "CL_INVALID_ARG_INDEX"; + case CL_INVALID_ARG_VALUE: return "CL_INVALID_ARG_VALUE"; + case CL_INVALID_ARG_SIZE: return "CL_INVALID_ARG_SIZE"; + case CL_INVALID_KERNEL_ARGS: return "CL_INVALID_KERNEL_ARGS"; + case CL_INVALID_WORK_DIMENSION: return "CL_INVALID_WORK_DIMENSION"; + case CL_INVALID_WORK_GROUP_SIZE: return "CL_INVALID_WORK_GROUP_SIZE"; + case CL_INVALID_WORK_ITEM_SIZE: return "CL_INVALID_WORK_ITEM_SIZE"; + case CL_INVALID_GLOBAL_OFFSET: return "CL_INVALID_GLOBAL_OFFSET"; + case CL_INVALID_EVENT_WAIT_LIST: return "CL_INVALID_EVENT_WAIT_LIST"; + case CL_INVALID_EVENT: return "CL_INVALID_EVENT"; + case CL_INVALID_OPERATION: return "CL_INVALID_OPERATION"; + case CL_INVALID_GL_OBJECT: return "CL_INVALID_GL_OBJECT"; + case CL_INVALID_BUFFER_SIZE: return "CL_INVALID_BUFFER_SIZE"; + case CL_INVALID_MIP_LEVEL: return "CL_INVALID_MIP_LEVEL"; + default: return "Unknown OpenCL error code"; + } +} + +/*! + * \brief Protected OpenCL call + * \param func Expression to call. + */ +#define OPENCL_CHECK_ERROR(e) \ + { \ + CHECK(e == CL_SUCCESS) \ + << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); \ + } + +#define OPENCL_CALL(func) \ + { \ + cl_int e = (func); \ + OPENCL_CHECK_ERROR(e); \ + } + +/*! + * \brief Process global OpenCL workspace. 
+ */ +class OpenCLWorkspace { + public: + // global platform id + cl_platform_id platform_id; + // global context of this process + cl_context context{nullptr}; + // the devices + std::vector devices; + // the queues + std::vector queues; + // Number of registered kernels + // Used to register kernel into the workspace. + size_t num_registered_kernels{0}; + // the mutex for initialization + std::mutex mu; + // destructor + ~OpenCLWorkspace() { + if (context != nullptr) { + OPENCL_CALL(clReleaseContext(context)); + } + } + // whether the workspace is initialized. + inline bool initialized() const { + return context != nullptr; + } + // get the queue of the context + cl_command_queue GetQueue(TVMContext ctx) const { + CHECK_EQ(ctx.dev_mask, kOpenCL); + CHECK(initialized()) + << "The OpenCL is not initialized"; + CHECK(ctx.dev_id >= 0 && static_cast(ctx.dev_id) < queues.size()) + << "Invalid OpenCL dev_id=" << ctx.dev_id; + return queues[ctx.dev_id]; + } + // get the global workspace + static OpenCLWorkspace* Global(); +}; + +/*! \brief Thread local workspace */ +class OpenCLThreadEntry { + public: + /*! \brief The current context */ + TVMContext context; + /*! \brief The thread-local kernel table */ + std::vector kernel_table; + + OpenCLThreadEntry() { + context.dev_id = 0; + context.dev_mask = kOpenCL; + } + // get the global workspace + static OpenCLThreadEntry* ThreadLocal(); +}; +} // namespace cl +/*! + * \brief Automatically detect and set cuda device. + * \param args The arguments. + */ +inline void AutoSetOpenCLContext(const TVMArgs& args) { + // TODO(tqchen): merge this with CUDA logic. 
+ int dev_id = -1; + for (int i = 0; i < args.size(); ++i) { + if (args.type_codes[i] == kArrayHandle) { + TVMContext ctx = static_cast( + args.values[i].v_handle)->ctx; + CHECK_EQ(ctx.dev_mask, kOpenCL) + << "All operands need to be GPU"; + if (dev_id == -1) { + dev_id = ctx.dev_id; + } else { + CHECK_EQ(dev_id, ctx.dev_id) + << "Operands comes from different devices "; + } + } + } + cl::OpenCLThreadEntry::ThreadLocal()->context.dev_id = dev_id; +} +} // namespace runtime +} // namespace tvm +#endif // TVM_OPENCL_RUNTIME +#endif // TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc new file mode 100644 index 000000000000..64bff819401c --- /dev/null +++ b/src/runtime/opencl/opencl_module.cc @@ -0,0 +1,164 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file opencl_module.cc + */ +#include "./opencl_common.h" +#include "./opencl_module.h" + +#if TVM_OPENCL_RUNTIME + +#include +#include +#include +#include "../void_addr_args.h" +#include "../thread_axis_args.h" + +namespace tvm { +namespace runtime { + +using namespace detail; + +/*! + * \brief Internal data structure to support multi-gpu execution. + * Try to use OpenCL runtime's primary context. + */ +class OpenCLModule::Internal { + public: + // the binary data + cl_program program; + // kernel id cache + std::unordered_map kid_map; + + explicit Internal(cl_program program) + : program(program) { + } + // destructor + ~Internal() { + OPENCL_CALL(clReleaseProgram(program)); + } + // get kernel id given key(function name. + size_t GetKernelID(const std::string& key) { + cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); + std::lock_guard lock(w->mu); + if (kid_map.count(key)) return kid_map.at(key); + size_t kid = w->num_registered_kernels++; + kid_map[key] = kid; + return kid; + } +}; + +class OpenCLWrappedFunc { + public: + // initialize the CUDA function. 
+ void Init(std::shared_ptr m, + size_t kernel_id, + std::string func_name, + std::vector arg_size, + const std::vector& thread_axis_tags) { + m_ = m; + kernel_id_ = kernel_id; + func_name_ = func_name; + arg_size_ = arg_size; + thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags); + } + // invoke the function with void arguments + void operator()(TVMArgs args, + TVMRetValue* rv, + void** void_args) const { + cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); + cl::OpenCLThreadEntry* t = cl::OpenCLThreadEntry::ThreadLocal(); + CHECK(w->initialized()); + // get the kernel from thread local kernel table. + if (kernel_id_ >= t->kernel_table.size()) { + t->kernel_table.resize(kernel_id_ + 1, nullptr); + } + cl_kernel kernel = t->kernel_table[kernel_id_]; + if (kernel == nullptr) { + cl_int err; + kernel = clCreateKernel(m_->program, func_name_.c_str(), &err); + OPENCL_CHECK_ERROR(err); + t->kernel_table[kernel_id_] = kernel; + } + // setup arguments. + for (cl_uint i = 0; i < arg_size_.size(); ++i) { + OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], void_args[i])); + } + cl_command_queue queue = w->GetQueue(t->context); + ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + cl_uint work_dim = static_cast(thread_axis_cfg_.work_dim()); + for (cl_uint i = 0; i < work_dim; ++i) { + wl.work_size[i + 3] *= wl.work_size[i]; + } + // launch kernel + OPENCL_CALL(clEnqueueNDRangeKernel( + queue, kernel, work_dim, nullptr, + wl.work_size + 3, + wl.work_size, + 0, nullptr, nullptr)); + } + + private: + // modulex + std::shared_ptr m_; + // global kernel id in the kernel table. + size_t kernel_id_; + // The name of the function. 
+ std::string func_name_; + // convert code for void argument + std::vector arg_size_; + // thread axis config + ThreadAxisConfig thread_axis_cfg_; +}; + +PackedFunc OpenCLModule::GetPackedFunc( + const std::string& func_name, + const std::vector arg_types, + const std::vector thread_axis_tags) const { + OpenCLWrappedFunc f; + // get the kernel id. + size_t kid = ptr_->GetKernelID(func_name); + std::vector arg_size(arg_types.size()); + for (size_t i = 0; i < arg_types.size(); ++i) { + TVMType t = arg_types[i]; + CHECK_EQ(t.lanes, 1U); + uint32_t bits = t.bits; + CHECK_EQ(bits % 8, 0U); + arg_size[i] = bits / 8; + } + // initialize the wrapped func. + f.Init(ptr_, kid, func_name, arg_size, thread_axis_tags); + return PackFromVoidAddrArgs(f, arg_types); +} + +OpenCLModule OpenCLModule::CreateWithSource(std::string source) { + cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); + CHECK(w->initialized()); + const char* s = source.c_str(); + size_t len = source.length(); + cl_int err; + cl_program prog = clCreateProgramWithSource( + w->context, 1, &s, &len, &err); + OPENCL_CHECK_ERROR(err); + + for (cl_device_id dev_id : w->devices) { + err = clBuildProgram(prog, 1, &dev_id, nullptr, nullptr, nullptr); + if (err != CL_SUCCESS) { + size_t len; + std::string log; + clGetProgramBuildInfo( + prog, dev_id, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len); + log.resize(len); + clGetProgramBuildInfo( + prog, dev_id, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); + LOG(FATAL) << "OpenCL build error for device=" << dev_id << log; + } + } + OpenCLModule m; + m.ptr_ = std::make_shared(prog); + return m; +} + +} // namespace runtime +} // namespace tvm + +#endif // TVM_OPENCL_RUNTIME diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h new file mode 100644 index 000000000000..e9d6fb5da395 --- /dev/null +++ b/src/runtime/opencl/opencl_module.h @@ -0,0 +1,48 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file opencl_module.h + * \brief Execution handling of OPENCL kernels + */ +#ifndef TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ +#define TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Handle execution of OPENCL kernels as PackedFunc. + * It wraps around driver API to work with OPENCL runtime API. + */ +class OpenCLModule { + public: + /*! + * \brief Get OpenCL Kernel launch wrapped as PackedFunc + * \param func_name The name of the function. + * \param arg_types The type of each argument in the function. + * \param thread_axis_tags The tag sequence of the thread axis. + */ + PackedFunc GetPackedFunc( + const std::string& func_name, + const std::vector arg_types, + const std::vector thread_axis_tags) const; + /*! + * \brief create a OpenCL module from data. + * \param source The module data. + */ + static OpenCLModule CreateWithSource(std::string source); + /*! \brief hidden internal data structure. */ + class Internal; + + private: + std::shared_ptr ptr_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/opencl/opencl_workspace.cc b/src/runtime/opencl/opencl_workspace.cc new file mode 100644 index 000000000000..1f79f4280bb6 --- /dev/null +++ b/src/runtime/opencl/opencl_workspace.cc @@ -0,0 +1,30 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file opencl_workspace.cc + */ +#include "./opencl_common.h" + +#if TVM_OPENCL_RUNTIME + +#include + +namespace tvm { +namespace runtime { +namespace cl { + +OpenCLWorkspace* OpenCLWorkspace::Global() { + static OpenCLWorkspace inst; + return &inst; +} + +typedef dmlc::ThreadLocalStore OpenCLThreadStore; + +OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { + return OpenCLThreadStore::Get(); +} + +} // namespace cl +} // namespace runtime +} // namespace tvm + +#endif // TVM_OPENCL_RUNTIME diff --git a/src/jit/stack_vm.cc b/src/runtime/stack_vm/stack_vm.cc similarity index 88% rename from src/jit/stack_vm.cc rename to src/runtime/stack_vm/stack_vm.cc index 80c4bcbf0513..f631dab0130c 100644 --- a/src/jit/stack_vm.cc +++ b/src/runtime/stack_vm/stack_vm.cc @@ -7,7 +7,7 @@ #include "./stack_vm.h" namespace tvm { -namespace jit { +namespace runtime { typedef dmlc::ThreadLocalStore StackVMStateStore; @@ -126,7 +126,6 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { STACK_VM_PRINT_CODE0(SELECT); STACK_VM_PRINT_HEAP_ACCESS(STORE_HEAP); STACK_VM_PRINT_HEAP_ACCESS(LOAD_HEAP); - STACK_VM_PRINT_CODE1(CALL_EXTERN); STACK_VM_PRINT_CODE1(ASSERT); STACK_VM_PRINT_JUMP(RJUMP_IF_TRUE); STACK_VM_PRINT_JUMP(RJUMP_IF_FALSE); @@ -143,6 +142,22 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_CODE); STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_BITS); STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_LANES); + // packed function. 
+ case CALL_PACKED_FUNC: { + int num_args = code[pc + 1].v_int; + os << "[" << pc << "]\tCALL_PACKED_FUNC " + << " num_args=" << num_args + << " fid=" << code[pc + 2].v_int; + os << " type_codes:"; + for (int i = 0; i < num_args; ++i) { + os << ' ' << code[pc + 3 + i].v_int; + } + os << '\n'; + for (int i = 0; i < num_args + 2; ++i) { + os << "[" << pc + 1 << "]" << std::endl; + } + return pc + 3 + num_args; + } } LOG(FATAL) << "unknown op code " << code[pc].op_code; return 0; @@ -160,6 +175,19 @@ std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) return os; } +void StackVM::operator()(const runtime::TVMArgs& args) const { + StackVM::State* s = StackVM::ThreadLocalState(); + s->sp = 0; + s->pc = 0; + if (s->heap.size() < this->heap_size) { + s->heap.resize(this->heap_size); + } + s->heap[0].v_handle = (void*)args.values; // NOLINT(*) + s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*) + s->heap[2].v_int64 = args.num_args; + this->Run(s); +} + void StackVM::Run(State* s) const { int64_t sp = s->sp; int64_t pc = s->pc; @@ -174,7 +202,6 @@ void StackVM::Run(State* s) const { heap.resize(heap_size); } const int64_t code_size = static_cast(code.size()); - while (pc < code_size) { switch (code[pc].op_code) { case ADD_I64: STACK_VM_BINOP(+, v_int64); break; @@ -252,13 +279,19 @@ void StackVM::Run(State* s) const { pc += 2; break; } - case CALL_EXTERN: { - int num_args = static_cast(stack[sp].v_int64); - int call_fid = code[pc + 1].v_int; - stack[sp - num_args] = extern_func[call_fid]( - &stack[sp - num_args], num_args); - sp = sp - num_args; - pc += 2; + case CALL_PACKED_FUNC: { + // call packed function. 
+      int num_args = code[pc + 1].v_int;
+      int call_fid = code[pc + 2].v_int;
+      static_assert(sizeof(Code) == sizeof(int) &&
+                    alignof(Code) == alignof(int), "assumption");
+      const int* type_codes = &(code[pc].v_int) + 3;
+      runtime::TVMRetValue rv;
+      packed_func[call_fid].CallPacked(
+          runtime::TVMArgs(&stack[sp + 1 - num_args], type_codes, num_args), &rv);
+      sp = sp + 1 - num_args;
+      stack[sp] = rv.value();
+      pc += 3 + num_args;
       break;
     }
     case ASSERT: {
@@ -331,5 +364,5 @@ void StackVM::Run(State* s) const {
   }
 }
-}  // namespace jit
+}  // namespace runtime
 }  // namespace tvm
diff --git a/src/jit/stack_vm.h b/src/runtime/stack_vm/stack_vm.h
similarity index 84%
rename from src/jit/stack_vm.h
rename to src/runtime/stack_vm/stack_vm.h
index b2d3c62ce721..d45f857eedc9 100644
--- a/src/jit/stack_vm.h
+++ b/src/runtime/stack_vm/stack_vm.h
@@ -5,24 +5,29 @@
  *
  * This can be used to interepret host side code
  * to setup calls into device functions
- * when only JIT for device is available(via NVRTC or OpenCL).
+ * when only Runtime compilation for device is available(via NVRTC or OpenCL).
  */
-#ifndef TVM_JIT_STACK_VM_H_
-#define TVM_JIT_STACK_VM_H_
+#ifndef TVM_RUNTIME_STACK_VM_STACK_VM_H_
+#define TVM_RUNTIME_STACK_VM_STACK_VM_H_
-#include
 #include
+#include
 #include
 #include
 namespace tvm {
-namespace jit {
+namespace runtime {
 /*!
  * \brief A simple stack-based virtual machine.
  */
 class StackVM {
  public:
+  /*!
+   * \brief Invoke the StackVM as PackedFunc
+   * \param args The arguments to the StackVM.
+   */
+  void operator()(const TVMArgs& args) const;
   /*!
    * \brief The opcode of stack vm
    * \note Notation
@@ -121,16 +126,19 @@ class StackVM {
    */
   SELECT,
   /*!
- * \brief call an extern function + * \brief call an extern packed function * \code * num_args = stack[sp].v_int64; * call_fid = code[pc + 1].v_int; * f = extern_func[call_fid]; - * stack[sp - num_args] = f(&stack[sp - num_args], num_args); + * int* type_codes = &(code[pc + 2].v_int) + * stack[sp - num_args] = f(&stack[sp - num_args], type_codes, num_args); * sp = sp - num_args; + * // The type codes are hidden in the code space. + * pc = pc + 2 + num_args * \endcode */ - CALL_EXTERN, + CALL_PACKED_FUNC, /*! * \brief Assert condition is true. * \code @@ -217,14 +225,12 @@ class StackVM { int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*) /*! \brief Get thread local state of the stack VM */ static State* ThreadLocalState(); - /*! \brief extern function that will mutate the state */ - using ExternFunc = std::function; /*! \brief The instructions */ std::vector code; /*! \brief constant error messages */ std::vector str_data; - /*! \brief Extern functions */ - std::vector extern_func; + /*! \brief Extern functions in packed func format */ + std::vector packed_func; /*! \brief name of each heap id*/ std::vector heap_id_name; /*! \brief The memory size needed */ @@ -254,20 +260,20 @@ class StackVM { * \param t the type code. 
* \return The load opcode */ - static OpCode GetLoad(Type t) { - CHECK_EQ(t.lanes(), 1); - if (t.is_handle()) return ADDR_LOAD_HANDLE; - if (t.is_int()) { - switch (t.bits()) { + static OpCode GetLoad(TVMType t) { + CHECK_EQ(t.lanes, 1U); + if (t.code == kHandle) return ADDR_LOAD_HANDLE; + if (t.code == kInt) { + switch (t.bits) { case 32 : return ADDR_LOAD_INT32; case 64 : return ADDR_LOAD_INT64; } - } else if (t.is_uint()) { - switch (t.bits()) { + } else if (t.code == kUInt) { + switch (t.bits) { case 32 : return ADDR_LOAD_UINT32; } - } else if (t.is_float()) { - switch (t.bits()) { + } else if (t.code == kFloat) { + switch (t.bits) { case 64 : return ADDR_LOAD_FP64; } } @@ -279,20 +285,19 @@ class StackVM { * \param t the type code. * \return The load opcode */ - static OpCode GetStore(Type t) { - CHECK_EQ(t.lanes(), 1); - if (t.is_int()) { - switch (t.bits()) { + static OpCode GetStore(TVMType t) { + CHECK_EQ(t.lanes, 1U); + if (t.code == kInt) { + switch (t.bits) { case 64 : return ADDR_STORE_INT64; } } LOG(FATAL) << "Cannot store type " << t; return ADDR_LOAD_FP64; } - friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*) }; -} // namespace jit +} // namespace runtime } // namespace tvm -#endif // TVM_JIT_STACK_VM_H_ +#endif // TVM_RUNTIME_STACK_VM_STACK_VM_H_ diff --git a/src/runtime/thread_axis_args.h b/src/runtime/thread_axis_args.h new file mode 100644 index 000000000000..96b34eaddece --- /dev/null +++ b/src/runtime/thread_axis_args.h @@ -0,0 +1,106 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file thread_axis_args.h + * \brief Extract thread axis configuration from TVMArgs. + */ +#ifndef TVM_RUNTIME_THREAD_AXIS_ARGS_H_ +#define TVM_RUNTIME_THREAD_AXIS_ARGS_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief workload speccification */ +struct ThreadWorkLoad { + // array, first three are thread configuration. + size_t work_size[6]; + /*! + * \param i The block dimension. 
+   * \return i-th block dim
+   */
+  inline size_t block_dim(size_t i) const {
+    return work_size[i];
+  }
+  /*!
+   * \param i The grid dimension.
+   * \return i-th grid dim
+   */
+  inline size_t grid_dim(size_t i) const {
+    return work_size[i + 3];
+  }
+};
+/*! \brief Thread axis configuration */
+class ThreadAxisConfig {
+ public:
+  void Init(size_t base,
+            const std::vector& thread_axis_tags) {
+    base_ = base;
+    std::vector filled(6, false);
+    for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
+      const std::string& tag = thread_axis_tags[i];
+      if (tag == "threadIdx.x") {
+        arg_index_map_.push_back(0);
+        filled[0] = true;
+      } else if (tag == "threadIdx.y") {
+        arg_index_map_.push_back(1);
+        filled[1] = true;
+      } else if (tag == "threadIdx.z") {
+        arg_index_map_.push_back(2);
+        filled[2] = true;
+      } else if (tag == "blockIdx.x") {
+        arg_index_map_.push_back(3 + 0);
+        filled[3] = true;
+      } else if (tag == "blockIdx.y") {
+        arg_index_map_.push_back(3 + 1);
+        filled[3 + 1] = true;
+      } else if (tag == "blockIdx.z") {
+        arg_index_map_.push_back(3 + 2);
+        filled[3 + 2] = true;
+      } else {
+        LOG(FATAL) << "do not know thread_tag=" << tag;
+      }
+    }
+    work_dim_ = 3;
+    for (int i = 0; i < 3; ++i) {
+      if (!filled[i]) {
+        for (int j = i; j < 3; ++j) {
+          CHECK(!filled[j] && !filled[j + 3])
+              << "Invalid thread group configuration";
+        }
+        work_dim_ = i;
+        break;
+      } else {
+        CHECK(filled[i])
+            << "Must have both threadIdx and blockIdx";
+      }
+    }
+  }
+  // extract workload from arguments.
+  ThreadWorkLoad Extract(TVMArgs x) const {
+    ThreadWorkLoad w;
+    std::fill(w.work_size, w.work_size + 6, 1);
+    for (size_t i = 0; i < arg_index_map_.size(); ++i) {
+      w.work_size[arg_index_map_[i]] =
+          static_cast(x.values[base_ + i].v_int64);
+    }
+    return w;
+  }
+  // return the work dim
+  size_t work_dim() const {
+    return work_dim_;
+  }
+
+ private:
+  /*! \brief base axis */
+  size_t base_;
+  /*! \brief The worker dimension */
+  size_t work_dim_;
+  /*! \brief The index mapping.
+   */
+  std::vector arg_index_map_;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_THREAD_AXIS_ARGS_H_
diff --git a/src/runtime/void_addr_args.h b/src/runtime/void_addr_args.h
new file mode 100644
index 000000000000..6f627339e8df
--- /dev/null
+++ b/src/runtime/void_addr_args.h
@@ -0,0 +1,164 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file void_addr_args.h
+ * \brief Utility to convert TVMArgs to void* array type-erasure function call.
+ *
+ * Array of argument address is a typical way of type-erasure for functions.
+ * The function signature looks like function(void** args, int num_args);
+ * Where args takes the address of each input.
+ */
+#ifndef TVM_RUNTIME_VOID_ADDR_ARGS_H_
+#define TVM_RUNTIME_VOID_ADDR_ARGS_H_
+
+#include
+#include
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Create a packed function from void addr types
+ * \param f with signature (TVMArgs args, TVMRetValue* rv, void* void_args)
+ * \param arg_types The arguments that wish to get from
+ * \tparam F the function type
+ *
+ * \return The wrapped packed function.
+ */
+template
+inline PackedFunc PackFromVoidAddrArgs(
+    F f, const std::vector& arg_types);
+
+// implementation details
+namespace detail {
+/*!
+ * \brief void addr argument data content
+ *  holder in case conversion is needed.
+ */
+union VoidArgHolder {
+  int32_t v_int32;
+  uint32_t v_uint32;
+  float v_float32;
+};
+
+template
+class VoidAddrArray {
+ public:
+  explicit VoidAddrArray(int num_args) {
+  }
+  void** addr() {
+    return addr_;
+  }
+  VoidArgHolder* holder() {
+    return holder_;
+  }
+
+ private:
+  void* addr_[MAX_NARG];
+  VoidArgHolder holder_[MAX_NARG];
+};
+
+template<>
+class VoidAddrArray<0> {
+ public:
+  explicit VoidAddrArray(int num_args)
+      : addr_(num_args), holder_(num_args) {
+  }
+  void** addr() {
+    return addr_.data();
+  }
+  VoidArgHolder* holder() {
+    return holder_.data();
+  }
+
+ private:
+  std::vector addr_;
+  std::vector holder_;
+};
+
+/*!
\brief conversion code used in void arg. */ +enum VoidArgConvertCode { + INT64_TO_INT64, + INT64_TO_INT32, + INT64_TO_UINT32, + FLOAT64_TO_FLOAT32, + FLOAT64_TO_FLOAT64, + HANDLE_TO_HANDLE +}; + +template +inline PackedFunc PackFromVoidAddrArgs_( + F f, const std::vector& codes) { + int num_args = static_cast(codes.size()); + auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { + VoidAddrArray temp(num_args); + void** addr = temp.addr(); + VoidArgHolder* holder = temp.holder(); + for (int i = 0; i < num_args; ++i) { + switch (codes[i]) { + case INT64_TO_INT64: + case FLOAT64_TO_FLOAT64: + case HANDLE_TO_HANDLE: { + addr[i] = (void*)&(args.values[i]); // NOLINT(*) + break; + } + case INT64_TO_INT32: { + holder[i].v_int32 = static_cast(args.values[i].v_int64); + addr[i] = &(holder[i]); + break; + } + case INT64_TO_UINT32 : { + holder[i].v_uint32 = static_cast(args.values[i].v_int64); + addr[i] = &(holder[i]); + break; + } + case FLOAT64_TO_FLOAT32: { + holder[i].v_float32 = static_cast(args.values[i].v_float64); + addr[i] = &(holder[i]); + break; + } + } + } + f(args, ret, addr); + }; + return PackedFunc(ret); +} + +inline VoidArgConvertCode GetVoidArgConvertCode(TVMType t) { + CHECK_EQ(t.lanes, 1U); + if (t.code == kInt) { + if (t.bits == 64U) return INT64_TO_INT64; + if (t.bits == 32U) return INT64_TO_INT32; + } else if (t.code == kUInt) { + if (t.bits == 32U) return INT64_TO_UINT32; + } else if (t.code == kFloat) { + if (t.bits == 64U) return FLOAT64_TO_FLOAT64; + if (t.bits == 32U) return FLOAT64_TO_FLOAT32; + } else if (t.code == kHandle) { + return HANDLE_TO_HANDLE; + } + LOG(FATAL) << "Cannot handle " << t; + return HANDLE_TO_HANDLE; +} + +} // namespace detail + +template +inline PackedFunc PackFromVoidAddrArgs( + F f, const std::vector& arg_types) { + std::vector codes(arg_types.size()); + for (size_t i = 0; i < arg_types.size(); ++i) { + codes[i] = detail::GetVoidArgConvertCode(arg_types[i]); + } + size_t num_void_args = arg_types.size(); + // 
specialization + if (num_void_args <= 4) { + return detail::PackFromVoidAddrArgs_<4>(f, codes); + } else if (num_void_args <= 8) { + return detail::PackFromVoidAddrArgs_<8>(f, codes); + } else { + return detail::PackFromVoidAddrArgs_<0>(f, codes); + } +} +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_VOID_ADDR_ARGS_H_ diff --git a/tests/python/test_codegen_cuda.py b/tests/python/test_codegen_cuda.py deleted file mode 100644 index 1c23122ca01b..000000000000 --- a/tests/python/test_codegen_cuda.py +++ /dev/null @@ -1,38 +0,0 @@ -import tvm -import numpy - -def mock_test_add(): - """Not yet working, mock design""" - n = tvm.Var('n') - A = tvm.placeholder((n,), name='A') - B = tvm.placeholder((n,), name='B') - C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') - s = tvm.Schedule(C.op) - - # GPU schedule have to split by gridIdx and threadIdx - num_thread = 256 - grid_x = tvm.IterVar(thread_tag="gridIdx.x") - thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") - _, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x) - _, x = s[C].split(x, outer=thread_x) - - # compile to IR - bounds = tvm.schedule.InferBound(s) - stmt = tvm.ir_pass.ScheduleOps(s, bounds) - - Ab = tvm.Buffer(A.shape, A.dtype, name='A') - Bb = tvm.Buffer(B.shape, B.dtype, name='B') - Cb = tvm.Buffer(C.shape, C.dtype, name='C') - stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) - stmt = tvm.ir_pass.Simplify(stmt) - print(stmt) - output_ssa = False - f = tvm.codegen.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 1) - - f_list = tvm.codegen.SplitHostDevice(f) - for x in f_list: - code = tvm.codegen.CompileToC(x, output_ssa) - print(code) - -if __name__ == "__main__": - mock_test_add() diff --git a/tests/python/test_codegen_device.py b/tests/python/test_codegen_device.py new file mode 100644 index 000000000000..e8aba60a9af8 --- /dev/null +++ b/tests/python/test_codegen_device.py @@ -0,0 +1,83 @@ +import tvm +import numpy as np + +def test_add_pipeline(): 
+ """Not yet working, mock design""" + n = tvm.Var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.Schedule(C.op) + + # GPU schedule have to split by gridIdx and threadIdx + num_thread = 256 + grid_x = tvm.IterVar(thread_tag="blockIdx.x") + thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") + _, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x) + _, x = s[C].split(x, outer=thread_x) + + # compile to IR + bounds = tvm.schedule.InferBound(s) + stmt = tvm.ir_pass.ScheduleOps(s, bounds) + + Ab = tvm.Buffer(A.shape, A.dtype, name='A') + Bb = tvm.Buffer(B.shape, B.dtype, name='B') + Cb = tvm.Buffer(C.shape, C.dtype, name='C') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) + stmt = tvm.ir_pass.Simplify(stmt) + fapi = tvm.codegen.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3) + fsplits = tvm.codegen.SplitHostDevice(fapi) + + def check_cuda(): + output_ssa = False + for f in fsplits[1:]: + print(tvm.codegen.CompileToC(f, output_ssa, "cuda")) + + # build and invoke the kernel. + fcuda = tvm.codegen.BuildNVRTC(fsplits, "stackvm") + num_device = 1 + for i in range(num_device): + ctx = tvm.gpu(i) + if not ctx.enabled: + continue + # launch the kernel. + n = 1027 + a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) + c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx) + fcuda(a, b, c) + np.testing.assert_allclose( + c.asnumpy(), a.asnumpy() + b.asnumpy()) + + def check_opencl(): + output_ssa = False + for f in fsplits[1:]: + print(tvm.codegen.CompileToC(f, output_ssa, "opencl")) + + # build and invoke the kernel. + fcl = tvm.codegen.BuildOpenCL(fsplits, "stackvm") + # Disable OpenCL runtime test for now, + # since the local worksize on CPU might be too large. 
+ num_device = 0 + for i in range(num_device): + ctx = tvm.cl(i) + if not ctx.enabled: + continue + # launch the kernel. + n = 1027 + a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) + c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx) + fcl(a, b, c) + np.testing.assert_allclose( + c.asnumpy(), a.asnumpy() + b.asnumpy()) + + tvm.init_opencl() + if tvm.cl(0).enabled: + check_opencl() + + if tvm.gpu(0).enabled: + check_cuda() + +if __name__ == "__main__": + test_add_pipeline() diff --git a/tests/python/test_runtime_packed_func.py b/tests/python/test_runtime_packed_func.py index 3332e9a3cdb7..bd9eada79b6d 100644 --- a/tests/python/test_runtime_packed_func.py +++ b/tests/python/test_runtime_packed_func.py @@ -1,16 +1,6 @@ import tvm import numpy as np -def test_function(): - ctx = tvm.cpu(0) - x = np.random.randint(0, 10, size=(3, 4)) - x = np.array(x) - y = tvm.nd.array(x, ctx=ctx) - - f = tvm.codegen.DummyHelloFunction() - f(y, 10) - - def test_get_global(): targs = (10, 10.0, "hello") # register into global function table diff --git a/tests/python/test_jit_stack_vm.py b/tests/python/test_runtime_stack_vm.py similarity index 96% rename from tests/python/test_jit_stack_vm.py rename to tests/python/test_runtime_stack_vm.py index d5c3bd712520..435df5faadf2 100644 --- a/tests/python/test_jit_stack_vm.py +++ b/tests/python/test_runtime_stack_vm.py @@ -10,6 +10,7 @@ def test_stack_vm_basic(): a = tvm.nd.array(np.zeros(10, dtype='float32')) @tvm.register_func def tvm_call_back_get_shape(shape0): + print(shape0) assert shape0 == a.shape[0] n = tvm.Var('n') @@ -74,3 +75,5 @@ def test_stack_vm_cond(): if __name__ == "__main__": test_stack_vm_cond() + test_stack_vm_loop() + test_stack_vm_basic()