From 5351e87bfe1e1d5ebb417c68852c4a7ded762bb3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 25 Sep 2018 21:41:28 +0000 Subject: [PATCH 1/3] support of multiple devices for tvm.build --- nnvm/include/nnvm/node.h | 6 + nnvm/python/nnvm/compiler/build_module.py | 16 +- nnvm/src/compiler/graph_compile.cc | 59 +++++- nnvm/src/compiler/graph_fuse.h | 5 + nnvm/src/pass/plan_memory.cc | 8 +- python/tvm/build_module.py | 196 ++++++++++-------- .../unittest/test_runtime_heterogeneous.py | 30 +-- 7 files changed, 194 insertions(+), 126 deletions(-) diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index ae782f04965e..1e7e5562402a 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -106,6 +106,12 @@ struct NodeAttrs { * stateful operators. */ std::vector > subgraphs; + /*! + * \brief Device information of the node. It indicates the device that this + * node should be executed. By default, the node is not annotated with any + * device. + */ + int device_type = 0; }; /*! diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 6fab4460b427..fc8f36143134 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -6,6 +6,7 @@ import tvm from tvm.contrib import graph_runtime +from tvm._ffi import runtime_ctypes from tvm import autotvm from . import graph_attr, graph_util from .. import graph as _graph @@ -118,10 +119,10 @@ def _lower(sch, inputs, func_name, graph): @tvm.register_func("nnvm.compiler.build_target") -def _build(funcs, target, target_host): +def _build(target_funcs, target_host): if target_host == "": target_host = None - return tvm.build(funcs, target=target, target_host=target_host) + return tvm.build(target_funcs, target_host=target_host) def _update_shape_dtype(shape, dtype, params): @@ -291,7 +292,16 @@ def build(graph, target=None, shape=None, dtype="float32", graph = graph_attr.set_shape_inputs(graph, shape) graph = graph.apply("InferShape") graph = graph_attr.set_dtype_inputs(graph, dtype) - graph._set_json_attr("target", str(target), "str") + targets = [] + if isinstance(target, (str, tvm.target.Target)): + targets = [str(target)] + elif isinstance(target, dict): + device_types = [runtime_ctypes.STR2MASK(dev) for dev in + target.keys()] + graph._set_json_attr("device_type", device_types, "list_int") + targets = [str(tar) for tar in target.values()] + graph._set_json_attr("target", targets, "list_str") + if target_host is not None: graph._set_json_attr("target_host", str(target_host), "str") if cfg.pass_enabled("OpFusion"): diff --git a/nnvm/src/compiler/graph_compile.cc b/nnvm/src/compiler/graph_compile.cc index 3316f3932e27..db286c5784c8 100644 --- a/nnvm/src/compiler/graph_compile.cc +++ b/nnvm/src/compiler/graph_compile.cc @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -77,8 +78,20 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { CHECK(g.HasAttr("fused_entry")) << "Fusion hasn't been applied yet."; FuseEntryVec fuse_entries = g.GetAttr("fused_entry"); - std::string target = g.GetAttr("target"); std::string target_host; + const std::vector& target = + g.GetAttr>("target"); + std::vector device_types{0}; + if (g.HasAttr("device_type")) { + device_types = g.GetAttr>("device_type"); + } + CHECK_EQ(target.size(), device_types.size()) + << "The number of compilation target doesn't match the given number of " + "devices."; + DeviceTargetMap dev_target_map; + for (size_t i = 0; i < device_types.size(); i++) { + 
dev_target_map.emplace(std::make_pair(device_types[i], target[i])); + } if (g.HasAttr("target_host")) { target_host = g.GetAttr("target_host"); @@ -87,7 +100,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { const nnvm::Op* assign_op = nnvm::Op::Get("_assign"); // Start lowering. - Array func_list; + std::unordered_map> tar_func_map; std::unordered_set func_set; const IndexedGraph& idx = g.indexed_graph(); @@ -95,9 +108,10 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; int root_id = group_vec[nid]; - if (static_cast(nid) != root_id) continue; + if (static_cast(nid) != root_id) continue; int master = master_vec[root_id]; FuseEntry& fe = fuse_entries[root_id]; + fe.device_type = inode.source->attrs.device_type; const IndexedGraph& subidx = fe.subgraph.indexed_graph(); CHECK_EQ(subidx.input_nodes().size(), fe.imap.size()); @@ -117,17 +131,20 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { } } CHECK_NE(sub_master_idx, -1) << "A master node not found in the subgraph."; - fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx); + CHECK(dev_target_map.count(fe.device_type)) + << "Cannot find the compilation target for device " << fe.device_type; + const auto& cur_target = dev_target_map[fe.device_type]; + fe.compiled_func = + GraphLower(fe.subgraph, inputs, cur_target, sub_master_idx); for (LoweredFunc f : fe.compiled_func->funcs) { if (!func_set.count(f.get())) { func_set.insert(f.get()); - func_list.push_back(f); + tar_func_map[cur_target].push_back(f); } } } const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op"); - std::unordered_map old_new; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; @@ -135,6 +152,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { // Only copy name since that is sufficient. nnvm::NodePtr np = nnvm::Node::Create(); np->attrs.name = inode.source->attrs.name; + np->attrs.device_type = inode.source->attrs.device_type; old_new[nid] = np; continue; } @@ -147,6 +165,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { nnvm::NodePtr np = nnvm::Node::Create(); np->attrs.op = tvm_op; np->attrs.name = inode.source->attrs.name; + np->attrs.device_type = inode.source->attrs.device_type; TVMOpParam param; param.func_name = fe.compiled_func->func_name; param.num_inputs = static_cast(fe.imap.size()); @@ -161,7 +180,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { auto rit = fe.reverse_imap.find(subidx[sub_input_id].source); CHECK(rit != fe.reverse_imap.end()); const IndexedGraph::NodeEntry& e = rit->second; - auto it = old_new.find(e.node_id); + auto it = old_new.find(e.node_id); CHECK(it != old_new.end()) << "cannot find node_id=" << e.node_id; np->inputs.emplace_back( @@ -174,6 +193,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { } old_new[nid] = np; } + nnvm::Graph ret; for (const auto& e : idx.outputs()) { auto it = old_new.find(group_vec[e.node_id]); @@ -197,7 +217,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { // // assign is a special operator that mutates the variable. 
// Currently assign is implemented as output = copy(input[1]) - // Then we run DecorageMemoryPlan to force + // Then we run DecorateMemoryPlan to force // output.storage = input[0].storage // std::vector assign_flag(new_idx.num_nodes(), 0); @@ -238,9 +258,28 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { ret.attrs["dtype"] = std::make_shared(std::move(new_dtype_vec)); ret.attrs["dltype"] = std::make_shared(std::move(new_dltype_vec)); - // Setup module + // Setup device assignment for heterogeneous execution. + if (tar_func_map.size() > 1) { + DeviceVector device_vec(new_idx.num_node_entries(), 0); + for (size_t i = 0; i < new_idx.num_nodes(); i++) { + device_vec[new_idx.entry_id(i, 0)] = + static_cast(new_idx[i].source->attrs.device_type); + } + for (uint32_t nid = 0; nid < new_idx.num_nodes(); nid++) { + const auto& inode = new_idx[nid]; + for (const auto& e : inode.inputs) { + device_vec[new_idx.entry_id(e)] = + static_cast(new_idx[e.node_id].source->attrs.device_type); + } + } + ret.attrs["device_index"] = std::make_shared(std::move(device_vec)); + } + // Setup module. static const PackedFunc& fbuild = GetPackedFunc("nnvm.compiler.build_target"); - tvm::runtime::Module module = fbuild(func_list, target, target_host); + tvm::runtime::Module module = + fbuild(tvm::Map>( + tar_func_map.begin(), tar_func_map.end()), + target_host); ret.attrs["module"] = std::make_shared(std::move(module)); ret = nnvm::ApplyPass(ret, "PlanMemory"); ret = DecorateMemoryPlan(ret, assign_flag); diff --git a/nnvm/src/compiler/graph_fuse.h b/nnvm/src/compiler/graph_fuse.h index 6faac7d3e162..abbb0257dbb0 100644 --- a/nnvm/src/compiler/graph_fuse.h +++ b/nnvm/src/compiler/graph_fuse.h @@ -7,6 +7,7 @@ #define NNVM_COMPILER_GRAPH_FUSE_H_ #include +#include #include #include "compile_engine.h" @@ -60,6 +61,8 @@ struct FuseEntry { bool flatten_data; // The corresponding function. GraphFunc compiled_func; + // The device that the fused op lowered to. + int device_type; }; // GroupVec stores the root node ids of the fused nodes. @@ -74,6 +77,8 @@ using FuseEntryVec = std::vector; // PatternVec stores operator patterns. using PatternVec = std::vector; +// DeviceTargetMap stores the device type to compilation target mapping info. +using DeviceTargetMap = std::unordered_map; } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index e0788386e6ea..f7ebcf5f2cf7 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -187,8 +187,8 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, const DTypeVector& dtype_vec = ret.GetAttr("dtype"); const DeviceVector* device_vec = nullptr; - if (ret.attrs.count("device") != 0) { - device_vec = &(ret.GetAttr("device")); + if (ret.attrs.count("device_index") != 0) { + device_vec = &(ret.GetAttr("device_index")); } size_t num_not_allocated = 0; std::vector storage_ref_count(idx.num_node_entries(), 0); @@ -237,8 +237,6 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, } } } - // normal allocation - const int dev_id = (device_vec != nullptr) ? device_vec->at(nid) : 0; // sort output nodes based on size before allocating output std::multimap eids; for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { @@ -253,6 +251,8 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, } for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) { uint32_t eid = rit->second; + // normal allocation + int dev_id = (device_vec != nullptr) ? 
device_vec->at(eid) : 0; auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); if (sid >= 0) { storage_ref_count[sid] = entry_ref_count[eid]; diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 8e0d16286d6a..492ed97b7b02 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -379,24 +379,40 @@ def lower(sch, return stmt return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) + def build(sch, args=None, target=None, target_host=None, name="default_function", - binds=None, - postpone_host_codegen=False): + binds=None): """Build a function with arguments as signature. Code will be generated - for a device specified by the target. For homogeneous execution, a module - that contains both host and device code is returned. For heterogeneous - execution, a list of lowered functions for the host and a module containing - device code are returned, but actual code generation for the host module is - postponed after code generation is finished for all devices. + for devices coupled with target information. + + There are two typical uses of this functioin: + 1. For backward compatibility + n = 2 + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.create_schedule(C.op) + f = tvm.lower(s, [A, B, C], name="test_add") + m = tvm.build(f, target="llvm") + + 2. For multi-device execution: + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s1 = tvm.create_schedule(C.op) + s2 = topi.cpp.cuda.schedule_injective("cuda", [C]) + f1 = tvm.lower(s1, [A, B, C], name="test_add1") + f2 = tvm.lower(s2, [A, B, C], name="test_add2") + m = tvm.build({"llvm":[f1], "cuda":[f2]}, target_host="llvm") Parameters ---------- - sch : tvm.Schedule, or LoweredFunc - The schedule to be builded + sch : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list + The schedule to be built args : list of Buffer or Tensor or Var, optional The argument lists to the function. @@ -420,18 +436,10 @@ def build(sch, Dictionary that maps the binding of symbolic buffer to Tensor. By default, a new buffer is created for each tensor in the argument. - postpone_host_codegen : bool, optional - A bool value that indicates if code generation for the host module - should be postponed. This variable is set to be true for heterogeneous - execution. Otherwise, it is defaulted to false. - Returns ------- ret : tvm.module, or (list of LoweredFunc, tvm.module) tuple - A module that combines both host and device code is returned when - postpone_host_codegen is not set. Otherwise, a list of lowered - functions for the host and a module contains only device code are - returned. + A module that combines both host and device code. 
Note ---- @@ -451,76 +459,96 @@ def build(sch, flist = [sch] elif isinstance(sch, (list, tuple, container.Array)): flist = sch + elif not isinstance(sch, (dict, container.Map)): + raise ValueError("sch have to be Schedule, LoweredFunc, list of " + "LoweredFunc, or dict of target to list of " + "LoweredFunct") + + if not isinstance(sch, (dict, container.Map)): + target = _target.current_target() if target is None else target + target = target if target else "llvm" + target_flist = {target: flist} else: - raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc") - fname_set = set() - for x in flist: - if not isinstance(x, container.LoweredFunc): - raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc") - if x.name in fname_set: - raise ValueError("Duplicate function name %s" % x.name) - fname_set.add(x.name) - - target = _target.current_target() if target is None else target - target = _target.create(target) if target else _target.create("llvm") - device_type = ndarray.context(target.target_name, 0).device_type - - fhost = [] - fdevice = [] - for func in flist: - if not ir_pass.VerifyMemory(func, device_type): - raise ValueError( - "Direct host side access to device memory is detected in %s. " - "Did you forget to bind?" % func.name) - if func.func_type == container.LoweredFunc.MixedFunc: - if current_build_config().detect_global_barrier: - func = ir_pass.ThreadSync(func, "global") - func = ir_pass.ThreadSync(func, "shared") - func = ir_pass.ThreadSync(func, "warp") - warp_size = target.thread_warp_size - func = ir_pass.LowerThreadAllreduce(func, warp_size) - fsplits = [s for s in ir_pass.SplitHostDevice(func)] - fhost.append(fsplits[0]) - for x in fsplits[1:]: - fdevice.append(x) - elif func.func_type == container.LoweredFunc.HostFunc: - fhost.append(func) - elif func.func_type == container.LoweredFunc.DeviceFunc: - fdevice.append(func) - else: - raise ValueError("unknown function type %d" % func.func_type) + target_flist = sch + + for tar, flist in target_flist.items(): + if not isinstance(tar, (str, _target.Target)): + raise ValueError("The key of sch must be str or " + "_target.Target when sch is dict.") + fname_set = set() + for x in flist: + if not isinstance(x, container.LoweredFunc): + raise ValueError("sch have to be Schedule, LoweredFunc, list " + "of LoweredFunc, or dict of str to list of " + "LoweredFunc.") + if x.name in fname_set: + raise ValueError("Duplicate function name %s" % x.name) + fname_set.add(x.name) + + fhost_all = [] + device_modules = [] + for tar, flist in target_flist.items(): + tar = _target.create(tar) + device_type = ndarray.context(tar.target_name, 0).device_type + fhost = [] + fdevice = [] + for func in flist: + if not ir_pass.VerifyMemory(func, device_type): + raise ValueError( + "Direct host side access to device memory is detected in %s. " + "Did you forget to bind?" 
% func.name) + if func.func_type == container.LoweredFunc.MixedFunc: + if current_build_config().detect_global_barrier: + func = ir_pass.ThreadSync(func, "global") + func = ir_pass.ThreadSync(func, "shared") + func = ir_pass.ThreadSync(func, "warp") + warp_size = tar.thread_warp_size + func = ir_pass.LowerThreadAllreduce(func, warp_size) + fsplits = [s for s in ir_pass.SplitHostDevice(func)] + fhost.append(fsplits[0]) + for x in fsplits[1:]: + fdevice.append(x) + elif func.func_type == container.LoweredFunc.HostFunc: + fhost.append(func) + elif func.func_type == container.LoweredFunc.DeviceFunc: + fdevice.append(func) + else: + raise ValueError("unknown function type %d" % func.func_type) - for i, func in enumerate(fdevice): - warp_size = target.thread_warp_size - fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size) + for i, func in enumerate(fdevice): + warp_size = tar.thread_warp_size + fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size) - if "gpu" in target.keys and not fdevice: - warnings.warn( - "Specified target %s, but cannot find device code, did you do bind?" % target) + if "gpu" in tar.keys and not fdevice: + warnings.warn( + "Specified target %s, but cannot find device code, did you do bind?" % tar) - fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] - fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost] + fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] + fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost] - if not target_host: - if device_type == ndarray.cpu(0).device_type: - target_host = target - assert not fdevice - else: - target_host = "llvm" if module.enabled("llvm") else "stackvm" - target_host = _target.create(target_host) - target_device = target - fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice] - fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] - fhost = [ir_pass.CombineContextCall(x) for x in fhost] - - # Append fhost to the device module and return the updated module. All - # device modules will be imported to the host module after all of them are - # collected. - mdev = codegen.build_module(fdevice, str(target_device)) if fdevice else None - if postpone_host_codegen: - return fhost, mdev - - mhost = codegen.build_module(fhost, str(target_host)) - if fdevice: + if not target_host: + device_type = ndarray.context(tar.target_name, 0).device_type + if device_type == ndarray.cpu(0).device_type: + target_host = tar + assert not fdevice + else: + target_host = "llvm" if module.enabled("llvm") else "stackvm" + + target_host = _target.create(target_host) + fdevice = [ir_pass.LowerIntrin(x, tar.target_name) for x in fdevice] + fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] + fhost = [ir_pass.CombineContextCall(x) for x in fhost] + + # Save the current lowered functions of the host and the device module. + fhost_all += fhost + if fdevice: + mdev = codegen.build_module(fdevice, str(tar)) + device_modules.append(mdev) + + # Generate a unified host module. + mhost = codegen.build_module(fhost_all, str(target_host)) + + # Import all modules. 
+ for mdev in device_modules: mhost.import_module(mdev) return mhost diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py index b916ee285717..3272165f0b02 100644 --- a/tests/python/unittest/test_runtime_heterogeneous.py +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -124,9 +124,6 @@ def check_device(device, target_device): schedule_add = topi.cpp.cuda.schedule_injective(target, [elemwise_add]) lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add], name="elemwise_add") - host_funcs_add, lib_add = tvm.build(lower_add, target=target_device, - name="elemwise_add", - postpone_host_codegen=True) # Insert copy. Neither compute nor schedule is required for the copy # node. The compute will be performed at runtime which is just data @@ -142,16 +139,8 @@ def check_device(device, target_device): elemwise_sub], name="elemwise_sub") - host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host, - name="elemwise_sub", - postpone_host_codegen=True) - host_funcs = host_funcs_add + host_funcs_sub - mhost = tvm.codegen.build_module(host_funcs, target_host) - if lib_add: - mhost.import_module(lib_add) - if lib_sub: - mhost.import_module(lib_sub) - + target_flist = {target_device: [lower_add], target_host: [lower_sub]} + mhost = tvm.build(target_flist, target_host=target_host) ctx = [host_ctx, device_ctx] mod = graph_runtime.create(graph, mhost, ctx) params = {} @@ -338,10 +327,6 @@ def check_device(device, target_device): lower_add1 = tvm.lower( add_schedule1, [tensor_d, copy_sub_add, elemwise_add1], name="elemwise_add1") - host_funcs_add, lib_add = tvm.build([lower_add0, lower_add1], - target=target_device, - postpone_host_codegen=True) - # Create module for sub whose target is the host. tensor_c = tvm.placeholder(shape, name="C") elemwise_sub = tvm.compute(shape, lambda *i: copy_add_sub(*i) @@ -350,15 +335,10 @@ def check_device(device, target_device): lower_sub = tvm.lower(sub_schedule, [copy_add_sub, tensor_c, elemwise_sub], name="elemwise_sub") - host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host, - postpone_host_codegen=True) - host_funcs = host_funcs_add + host_funcs_sub - mhost = tvm.codegen.build_module(host_funcs, target_host) - if lib_add: - mhost.import_module(lib_add) - if lib_sub: - mhost.import_module(lib_sub) + target_flist = {target_device: [lower_add0, lower_add1], target_host: + [lower_sub]} + mhost = tvm.build(target_flist, target_host=target_host) ctx = [host_ctx, device_ctx] params = {} params["A"] = tensor_a = np.random.uniform( From 8c05834b1f7113eca2cfbde219bd3ca158d68cbc Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 26 Sep 2018 16:15:47 +0000 Subject: [PATCH 2/3] only keep changes to tvm.build --- nnvm/include/nnvm/node.h | 6 --- nnvm/python/nnvm/compiler/build_module.py | 16 ++---- nnvm/src/compiler/graph_compile.cc | 59 ++++------------------- nnvm/src/compiler/graph_fuse.h | 5 -- nnvm/src/pass/plan_memory.cc | 8 +-- 5 files changed, 17 insertions(+), 77 deletions(-) diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 1e7e5562402a..ae782f04965e 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -106,12 +106,6 @@ struct NodeAttrs { * stateful operators. */ std::vector > subgraphs; - /*! - * \brief Device information of the node. It indicates the device that this - * node should be executed. By default, the node is not annotated with any - * device. - */ - int device_type = 0; }; /*! 
diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index fc8f36143134..6fab4460b427 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -6,7 +6,6 @@ import tvm from tvm.contrib import graph_runtime -from tvm._ffi import runtime_ctypes from tvm import autotvm from . import graph_attr, graph_util from .. import graph as _graph @@ -119,10 +118,10 @@ def _lower(sch, inputs, func_name, graph): @tvm.register_func("nnvm.compiler.build_target") -def _build(target_funcs, target_host): +def _build(funcs, target, target_host): if target_host == "": target_host = None - return tvm.build(target_funcs, target_host=target_host) + return tvm.build(funcs, target=target, target_host=target_host) def _update_shape_dtype(shape, dtype, params): @@ -292,16 +291,7 @@ def build(graph, target=None, shape=None, dtype="float32", graph = graph_attr.set_shape_inputs(graph, shape) graph = graph.apply("InferShape") graph = graph_attr.set_dtype_inputs(graph, dtype) - targets = [] - if isinstance(target, (str, tvm.target.Target)): - targets = [str(target)] - elif isinstance(target, dict): - device_types = [runtime_ctypes.STR2MASK(dev) for dev in - target.keys()] - graph._set_json_attr("device_type", device_types, "list_int") - targets = [str(tar) for tar in target.values()] - graph._set_json_attr("target", targets, "list_str") - + graph._set_json_attr("target", str(target), "str") if target_host is not None: graph._set_json_attr("target_host", str(target_host), "str") if cfg.pass_enabled("OpFusion"): diff --git a/nnvm/src/compiler/graph_compile.cc b/nnvm/src/compiler/graph_compile.cc index db286c5784c8..3316f3932e27 100644 --- a/nnvm/src/compiler/graph_compile.cc +++ b/nnvm/src/compiler/graph_compile.cc @@ -13,7 +13,6 @@ #include #include #include -#include #include #include @@ -78,20 +77,8 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { CHECK(g.HasAttr("fused_entry")) << "Fusion hasn't been applied yet."; FuseEntryVec fuse_entries = g.GetAttr("fused_entry"); + std::string target = g.GetAttr("target"); std::string target_host; - const std::vector& target = - g.GetAttr>("target"); - std::vector device_types{0}; - if (g.HasAttr("device_type")) { - device_types = g.GetAttr>("device_type"); - } - CHECK_EQ(target.size(), device_types.size()) - << "The number of compilation target doesn't match the given number of " - "devices."; - DeviceTargetMap dev_target_map; - for (size_t i = 0; i < device_types.size(); i++) { - dev_target_map.emplace(std::make_pair(device_types[i], target[i])); - } if (g.HasAttr("target_host")) { target_host = g.GetAttr("target_host"); @@ -100,7 +87,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { const nnvm::Op* assign_op = nnvm::Op::Get("_assign"); // Start lowering. 
- std::unordered_map> tar_func_map; + Array func_list; std::unordered_set func_set; const IndexedGraph& idx = g.indexed_graph(); @@ -108,10 +95,9 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; int root_id = group_vec[nid]; - if (static_cast(nid) != root_id) continue; + if (static_cast(nid) != root_id) continue; int master = master_vec[root_id]; FuseEntry& fe = fuse_entries[root_id]; - fe.device_type = inode.source->attrs.device_type; const IndexedGraph& subidx = fe.subgraph.indexed_graph(); CHECK_EQ(subidx.input_nodes().size(), fe.imap.size()); @@ -131,20 +117,17 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { } } CHECK_NE(sub_master_idx, -1) << "A master node not found in the subgraph."; - CHECK(dev_target_map.count(fe.device_type)) - << "Cannot find the compilation target for device " << fe.device_type; - const auto& cur_target = dev_target_map[fe.device_type]; - fe.compiled_func = - GraphLower(fe.subgraph, inputs, cur_target, sub_master_idx); + fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx); for (LoweredFunc f : fe.compiled_func->funcs) { if (!func_set.count(f.get())) { func_set.insert(f.get()); - tar_func_map[cur_target].push_back(f); + func_list.push_back(f); } } } const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op"); + std::unordered_map old_new; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; @@ -152,7 +135,6 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { // Only copy name since that is sufficient. nnvm::NodePtr np = nnvm::Node::Create(); np->attrs.name = inode.source->attrs.name; - np->attrs.device_type = inode.source->attrs.device_type; old_new[nid] = np; continue; } @@ -165,7 +147,6 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { nnvm::NodePtr np = nnvm::Node::Create(); np->attrs.op = tvm_op; np->attrs.name = inode.source->attrs.name; - np->attrs.device_type = inode.source->attrs.device_type; TVMOpParam param; param.func_name = fe.compiled_func->func_name; param.num_inputs = static_cast(fe.imap.size()); @@ -180,7 +161,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { auto rit = fe.reverse_imap.find(subidx[sub_input_id].source); CHECK(rit != fe.reverse_imap.end()); const IndexedGraph::NodeEntry& e = rit->second; - auto it = old_new.find(e.node_id); + auto it = old_new.find(e.node_id); CHECK(it != old_new.end()) << "cannot find node_id=" << e.node_id; np->inputs.emplace_back( @@ -193,7 +174,6 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { } old_new[nid] = np; } - nnvm::Graph ret; for (const auto& e : idx.outputs()) { auto it = old_new.find(group_vec[e.node_id]); @@ -217,7 +197,7 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { // // assign is a special operator that mutates the variable. // Currently assign is implemented as output = copy(input[1]) - // Then we run DecorateMemoryPlan to force + // Then we run DecorageMemoryPlan to force // output.storage = input[0].storage // std::vector assign_flag(new_idx.num_nodes(), 0); @@ -258,28 +238,9 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) { ret.attrs["dtype"] = std::make_shared(std::move(new_dtype_vec)); ret.attrs["dltype"] = std::make_shared(std::move(new_dltype_vec)); - // Setup device assignment for heterogeneous execution. 
- if (tar_func_map.size() > 1) { - DeviceVector device_vec(new_idx.num_node_entries(), 0); - for (size_t i = 0; i < new_idx.num_nodes(); i++) { - device_vec[new_idx.entry_id(i, 0)] = - static_cast(new_idx[i].source->attrs.device_type); - } - for (uint32_t nid = 0; nid < new_idx.num_nodes(); nid++) { - const auto& inode = new_idx[nid]; - for (const auto& e : inode.inputs) { - device_vec[new_idx.entry_id(e)] = - static_cast(new_idx[e.node_id].source->attrs.device_type); - } - } - ret.attrs["device_index"] = std::make_shared(std::move(device_vec)); - } - // Setup module. + // Setup module static const PackedFunc& fbuild = GetPackedFunc("nnvm.compiler.build_target"); - tvm::runtime::Module module = - fbuild(tvm::Map>( - tar_func_map.begin(), tar_func_map.end()), - target_host); + tvm::runtime::Module module = fbuild(func_list, target, target_host); ret.attrs["module"] = std::make_shared(std::move(module)); ret = nnvm::ApplyPass(ret, "PlanMemory"); ret = DecorateMemoryPlan(ret, assign_flag); diff --git a/nnvm/src/compiler/graph_fuse.h b/nnvm/src/compiler/graph_fuse.h index abbb0257dbb0..6faac7d3e162 100644 --- a/nnvm/src/compiler/graph_fuse.h +++ b/nnvm/src/compiler/graph_fuse.h @@ -7,7 +7,6 @@ #define NNVM_COMPILER_GRAPH_FUSE_H_ #include -#include #include #include "compile_engine.h" @@ -61,8 +60,6 @@ struct FuseEntry { bool flatten_data; // The corresponding function. GraphFunc compiled_func; - // The device that the fused op lowered to. - int device_type; }; // GroupVec stores the root node ids of the fused nodes. @@ -77,8 +74,6 @@ using FuseEntryVec = std::vector; // PatternVec stores operator patterns. using PatternVec = std::vector; -// DeviceTargetMap stores the device type to compilation target mapping info. -using DeviceTargetMap = std::unordered_map; } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index f7ebcf5f2cf7..e0788386e6ea 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -187,8 +187,8 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, const DTypeVector& dtype_vec = ret.GetAttr("dtype"); const DeviceVector* device_vec = nullptr; - if (ret.attrs.count("device_index") != 0) { - device_vec = &(ret.GetAttr("device_index")); + if (ret.attrs.count("device") != 0) { + device_vec = &(ret.GetAttr("device")); } size_t num_not_allocated = 0; std::vector storage_ref_count(idx.num_node_entries(), 0); @@ -237,6 +237,8 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, } } } + // normal allocation + const int dev_id = (device_vec != nullptr) ? device_vec->at(nid) : 0; // sort output nodes based on size before allocating output std::multimap eids; for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { @@ -251,8 +253,6 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, } for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) { uint32_t eid = rit->second; - // normal allocation - int dev_id = (device_vec != nullptr) ? 
device_vec->at(eid) : 0; auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); if (sid >= 0) { storage_ref_count[sid] = entry_ref_count[eid]; From df6bb16dbf3612a4e89e017cb6fde77f9d97d460 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 28 Sep 2018 04:54:10 +0000 Subject: [PATCH 3/3] separate function to build module for each device --- python/tvm/build_module.py | 227 ++++++++++++++++++++++--------------- 1 file changed, 135 insertions(+), 92 deletions(-) diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 492ed97b7b02..2bb7442bab76 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -380,7 +380,81 @@ def lower(sch, return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) -def build(sch, +def _build_for_device(flist, target, target_host): + """Build the lowered functions for a device with the given compilation + target. + + Parameters + ---------- + flist : list of LoweredFunc + The schedule to be built. + + target : str or :any:`tvm.target.Target` + The target and option of the compilation. + + target_host : str or :any:`tvm.target.Target` + The host compilation target. + + Returns + ------- + fhost : list of LoweredFunc + A list of lowered functions for the host. + + mdev : tvm.module + A module that contains device code. + """ + target = _target.create(target) + device_type = ndarray.context(target.target_name, 0).device_type + fhost = [] + fdevice = [] + for func in flist: + if not ir_pass.VerifyMemory(func, device_type): + raise ValueError( + "Direct host side access to device memory is detected in %s. " + "Did you forget to bind?" % func.name) + if func.func_type == container.LoweredFunc.MixedFunc: + if current_build_config().detect_global_barrier: + func = ir_pass.ThreadSync(func, "global") + func = ir_pass.ThreadSync(func, "shared") + func = ir_pass.ThreadSync(func, "warp") + warp_size = target.thread_warp_size + func = ir_pass.LowerThreadAllreduce(func, warp_size) + fsplits = [s for s in ir_pass.SplitHostDevice(func)] + fhost.append(fsplits[0]) + for x in fsplits[1:]: + fdevice.append(x) + elif func.func_type == container.LoweredFunc.HostFunc: + fhost.append(func) + elif func.func_type == container.LoweredFunc.DeviceFunc: + fdevice.append(func) + else: + raise ValueError("unknown function type %d" % func.func_type) + + for i, func in enumerate(fdevice): + warp_size = target.thread_warp_size + fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size) + + if "gpu" in target.keys and not fdevice: + warnings.warn( + "Specified target %s, but cannot find device code, did you do " + "bind?" % target) + + fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] + fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost] + + if device_type == ndarray.cpu(0).device_type and target_host == target: + assert not fdevice + + target_host = _target.create(target_host) + fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] + fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] + fhost = [ir_pass.CombineContextCall(x) for x in fhost] + mdev = codegen.build_module(fdevice, str(target)) if fdevice else None + + return fhost, mdev + + +def build(inputs, args=None, target=None, target_host=None, @@ -389,29 +463,9 @@ def build(sch, """Build a function with arguments as signature. Code will be generated for devices coupled with target information. - There are two typical uses of this functioin: - 1. 
For backward compatibility - n = 2 - A = tvm.placeholder((n,), name='A') - B = tvm.placeholder((n,), name='B') - C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') - s = tvm.create_schedule(C.op) - f = tvm.lower(s, [A, B, C], name="test_add") - m = tvm.build(f, target="llvm") - - 2. For multi-device execution: - A = tvm.placeholder((n,), name='A') - B = tvm.placeholder((n,), name='B') - C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') - s1 = tvm.create_schedule(C.op) - s2 = topi.cpp.cuda.schedule_injective("cuda", [C]) - f1 = tvm.lower(s1, [A, B, C], name="test_add1") - f2 = tvm.lower(s2, [A, B, C], name="test_add2") - m = tvm.build({"llvm":[f1], "cuda":[f2]}, target_host="llvm") - Parameters ---------- - sch : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list + inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list The schedule to be built args : list of Buffer or Tensor or Var, optional @@ -438,117 +492,106 @@ def build(sch, Returns ------- - ret : tvm.module, or (list of LoweredFunc, tvm.module) tuple + ret : tvm.module A module that combines both host and device code. + Examples + ________ + There are two typical example uses of this function depending on the type + of the argument `inputs`: + 1. it is a list of lowered functions: + + .. code-block:: python + + n = 2 + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.create_schedule(C.op) + f = tvm.lower(s, [A, B, C], name="test_add") + m = tvm.build(f, target="llvm") + + 2. it is a dict of compilation target to list of lowered functions: + + .. code-block:: python + + n = 2 + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s1 = tvm.create_schedule(C.op) + s2 = topi.cpp.cuda.schedule_injective("cuda", [C]) + f1 = tvm.lower(s1, [A, B, C], name="test_add1") + f2 = tvm.lower(s2, [A, B, C], name="test_add2") + m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm") + Note ---- See the note on :any:`tvm.target` on target string format. 
""" - if isinstance(sch, schedule.Schedule): + if isinstance(inputs, schedule.Schedule): if args is None: raise ValueError("args must be given for build from schedule") - flist = lower(sch, args, + flist = lower(inputs, args, name=name, binds=binds) if isinstance(flist, container.LoweredFunc): flist = [flist] - elif isinstance(sch, container.LoweredFunc): + elif isinstance(inputs, container.LoweredFunc): if args: - raise ValueError("args must be done when build from LoweredFunc") - flist = [sch] - elif isinstance(sch, (list, tuple, container.Array)): - flist = sch - elif not isinstance(sch, (dict, container.Map)): - raise ValueError("sch have to be Schedule, LoweredFunc, list of " + raise ValueError("args must be done when build from LoweredFunc.") + flist = [inputs] + elif isinstance(inputs, (list, tuple, container.Array)): + flist = inputs + elif not isinstance(inputs, (dict, container.Map)): + raise ValueError("inputs must be Schedule, LoweredFunc, list of " "LoweredFunc, or dict of target to list of " - "LoweredFunct") + "LoweredFunc.") - if not isinstance(sch, (dict, container.Map)): + if not isinstance(inputs, (dict, container.Map)): target = _target.current_target() if target is None else target target = target if target else "llvm" target_flist = {target: flist} else: - target_flist = sch + target_flist = inputs for tar, flist in target_flist.items(): if not isinstance(tar, (str, _target.Target)): - raise ValueError("The key of sch must be str or " - "_target.Target when sch is dict.") + raise ValueError("The key of inputs must be str or " + "_target.Target when inputs is dict.") fname_set = set() for x in flist: if not isinstance(x, container.LoweredFunc): - raise ValueError("sch have to be Schedule, LoweredFunc, list " + raise ValueError("inputs must be Schedule, LoweredFunc, list " "of LoweredFunc, or dict of str to list of " "LoweredFunc.") if x.name in fname_set: raise ValueError("Duplicate function name %s" % x.name) fname_set.add(x.name) - fhost_all = [] - device_modules = [] - for tar, flist in target_flist.items(): - tar = _target.create(tar) - device_type = ndarray.context(tar.target_name, 0).device_type - fhost = [] - fdevice = [] - for func in flist: - if not ir_pass.VerifyMemory(func, device_type): - raise ValueError( - "Direct host side access to device memory is detected in %s. " - "Did you forget to bind?" % func.name) - if func.func_type == container.LoweredFunc.MixedFunc: - if current_build_config().detect_global_barrier: - func = ir_pass.ThreadSync(func, "global") - func = ir_pass.ThreadSync(func, "shared") - func = ir_pass.ThreadSync(func, "warp") - warp_size = tar.thread_warp_size - func = ir_pass.LowerThreadAllreduce(func, warp_size) - fsplits = [s for s in ir_pass.SplitHostDevice(func)] - fhost.append(fsplits[0]) - for x in fsplits[1:]: - fdevice.append(x) - elif func.func_type == container.LoweredFunc.HostFunc: - fhost.append(func) - elif func.func_type == container.LoweredFunc.DeviceFunc: - fdevice.append(func) - else: - raise ValueError("unknown function type %d" % func.func_type) - - for i, func in enumerate(fdevice): - warp_size = tar.thread_warp_size - fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size) - - if "gpu" in tar.keys and not fdevice: - warnings.warn( - "Specified target %s, but cannot find device code, did you do bind?" 
% tar) - - fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] - fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost] - - if not target_host: + if not target_host: + for tar, _ in target_flist.items(): + tar = _target.create(tar) device_type = ndarray.context(tar.target_name, 0).device_type if device_type == ndarray.cpu(0).device_type: target_host = tar - assert not fdevice - else: - target_host = "llvm" if module.enabled("llvm") else "stackvm" - - target_host = _target.create(target_host) - fdevice = [ir_pass.LowerIntrin(x, tar.target_name) for x in fdevice] - fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] - fhost = [ir_pass.CombineContextCall(x) for x in fhost] + break + if not target_host: + target_host = "llvm" if module.enabled("llvm") else "stackvm" + fhost_all = [] + device_modules = [] + for tar, flist in target_flist.items(): + fhost, mdev = _build_for_device(flist, tar, target_host) # Save the current lowered functions of the host and the device module. fhost_all += fhost - if fdevice: - mdev = codegen.build_module(fdevice, str(tar)) - device_modules.append(mdev) + device_modules.append(mdev) # Generate a unified host module. mhost = codegen.build_module(fhost_all, str(target_host)) # Import all modules. for mdev in device_modules: - mhost.import_module(mdev) + if mdev: + mhost.import_module(mdev) return mhost
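A minimal end-to-end sketch of the interface after all three patches are applied, for reviewers who want to try it locally. It assumes a TVM build with both LLVM and CUDA enabled; the tensor shapes and the function names (vadd_cpu, vadd_gpu) are illustrative only and not part of the patch.

import numpy as np
import tvm
import topi

n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.placeholder((n,), name="B")
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")

# Lower the same computation once per compilation target.
s_cpu = tvm.create_schedule(C.op)
s_gpu = topi.cpp.cuda.schedule_injective("cuda", [C])
f_cpu = tvm.lower(s_cpu, [A, B, C], name="vadd_cpu")
f_gpu = tvm.lower(s_gpu, [A, B, C], name="vadd_gpu")

# New form: a dict mapping each target to its list of lowered functions.
# tvm.build lowers each list via _build_for_device, emits one unified host
# module, and imports one device module per target into it.
mod = tvm.build({"llvm": [f_cpu], "cuda": [f_gpu]}, target_host="llvm")

# Backward-compatible form: a single schedule or LoweredFunc plus target=.
mod_cpu = tvm.build(f_cpu, target="llvm")

# Call the GPU entry point from the unified host module.
ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
mod["vadd_gpu"](a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())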