From abc80d57f48f4bdf4625a17a53a93adb0d90bbff Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 28 Jul 2020 16:39:51 +0000 Subject: [PATCH 01/21] vm heterogeneous execution --- include/tvm/runtime/vm/bytecode.h | 23 +- include/tvm/runtime/vm/executable.h | 2 + include/tvm/runtime/vm/vm.h | 11 +- python/tvm/relay/analysis/__init__.py | 2 + python/tvm/relay/analysis/context_analysis.py | 423 ++++++++++++++++++ python/tvm/relay/backend/vm.py | 7 - python/tvm/relay/op/_tensor.py | 2 + python/tvm/relay/op/_transform.py | 14 +- python/tvm/relay/transform/memory_alloc.py | 116 ++++- python/tvm/relay/transform/memory_plan.py | 8 + python/tvm/runtime/vm.py | 12 +- src/relay/backend/vm/compiler.cc | 186 ++++++-- src/relay/backend/vm/compiler.h | 2 + src/relay/op/device_copy.cc | 8 +- src/relay/transforms/fold_constant.cc | 4 +- src/runtime/vm/bytecode.cc | 37 +- src/runtime/vm/executable.cc | 17 +- src/runtime/vm/profiler/vm.cc | 9 +- src/runtime/vm/vm.cc | 49 +- tests/python/relay/test_any.py | 213 +++------ tests/python/relay/test_pass_annotation.py | 115 +++-- .../relay/test_pass_context_analysis.py | 99 ++++ tests/python/relay/test_vm.py | 2 +- 23 files changed, 1077 insertions(+), 284 deletions(-) create mode 100644 python/tvm/relay/analysis/context_analysis.py create mode 100644 tests/python/relay/test_pass_context_analysis.py diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index 89a3164f7483..cb9a59a9ab93 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -66,6 +66,7 @@ enum class Opcode { AllocStorage = 16U, ShapeOf = 17U, ReshapeTensor = 18U, + DeviceCopy = 19U, }; /*! \brief A single virtual machine instruction. @@ -196,6 +197,8 @@ struct Instruction { Index alignment; /*! \brief The hint of the dtype. */ DLDataType dtype_hint; + /*! \brief The device type of the allocation. */ + Index device_type; } alloc_storage; struct /* ShapeOf Operands */ { RegName tensor; @@ -204,6 +207,13 @@ struct Instruction { RegName tensor; RegName newshape; } reshape_tensor; + struct /* DeviceCopy Operands */ { + RegName src; + /*! \brief The source device type. */ + Index src_device_type; + /*! \brief The destination device type. */ + Index dst_device_type; + }; }; /*! @@ -341,11 +351,12 @@ struct Instruction { * \param size The size of the allocation. * \param alignment The allocation's alignment. * \param dtype_hint The data type hint for the allocator. + * \param device_type The device type for the allocator. * \param dst The destination to place the storage. * \return The alloc storage instruction. */ static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - RegName dst); + Index device_type, RegName dst); /*! * \brief Get the shape of an input tensor. * \param tensor The input tensor. @@ -361,6 +372,16 @@ struct Instruction { * \return The reshape tensor instruction. */ static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst); + /*! + * \brief Copy tensor cross different devices. + * \param src The source register. + * \param src_device_type The device type of the tensor for the source register. + * \param dst_device_type The device type of the tensor ofr the destination register. + * \param dst The destination register to store the copied tensor. + * \return The reshape tensor instruction. + */ + static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, + RegName dst); Instruction(); Instruction(const Instruction& instr); diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index cc38da75a0c7..8d3f651758d1 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -161,6 +161,8 @@ class Executable : public ModuleNode { std::unordered_map primitive_map; /*! \brief The virtual machine's function table. */ std::vector functions; + /*! \brief The device type for each constant. */ + std::vector const_device_type; private: /*! diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 273b8fe60847..ba2585b49c3d 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -83,13 +83,17 @@ struct VMFunction { std::vector instructions; /*! \brief The size of the frame for this function */ Index register_file_size; + /*! \brief The device type of each parameter for this function. */ + std::vector params_device_type; VMFunction(const std::string& name, std::vector params, - const std::vector& instructions, Index register_file_size) + const std::vector& instructions, Index register_file_size, + const std::vector params_device_type = {}) : name(name), params(params), instructions(instructions), - register_file_size(register_file_size) {} + register_file_size(register_file_size), + params_device_type(params_device_type) {} VMFunction() {} @@ -247,6 +251,9 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief Get device context for params. */ TVMContext GetParamsContext() const; + /*! \brief Get context from the context list based on a given device type. */ + TVMContext GetContext(Index device_type) const; + /*! * \brief Invoke a global setting up the VM state to execute. * diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py index e5b21cb107f5..0c065ef07a6e 100644 --- a/python/tvm/relay/analysis/__init__.py +++ b/python/tvm/relay/analysis/__init__.py @@ -29,3 +29,5 @@ # Feature from . import feature from . import sparse_dense + +from . import context_analysis diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py new file mode 100644 index 000000000000..ac7a146312f7 --- /dev/null +++ b/python/tvm/relay/analysis/context_analysis.py @@ -0,0 +1,423 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks +""" +A pass for analyzing device attribute of each IR node. +""" +from typing import Optional +from collections import defaultdict + +from ..expr_functor import ExprVisitor +from ..function import Function +from .. import op, expr as _expr +from ... import register_func, cpu +from ..._ffi.runtime_ctypes import TVMContext + +def is_device_copy(call): + """Check if a call node is a device copy call. + Parameters + ---------- + call : tvm.relay.Call + The call node to be checked. + Returns + ------- + ret : Boolean + True if the call is a device copy call. Otherwise, false. + """ + if not isinstance(call, _expr.Call): + return False + if call.op == op.op.get("device_copy"): + return True + + if not isinstance(call.op, Function): + return False + return isinstance(call.op.body, _expr.Call) and \ + call.op.body.op == op.op.get("device_copy") + +class DeviceDomain: + """A class to represent the device of a domain, i.e. a segment of relay + program. + Parameters + ---------- + ctx : Optional[tvm.runtime.TVMContext] + The device to be assigned to the current domain. It is optional. + """ + def __init__(self, ctx: Optional[TVMContext]): + self.domain = ctx + + def join(self, other: 'DeviceDomain') -> 'DeviceDomain': + """Merge the device of two domains. + Parameters + ---------- + other : DeviceDomain + The other domain to be merged. + Returns + ------- + ret : DeviceDomain + The merged domain. An error will be raised if two domain has + conflict, i.e. they have different context. + """ + if self.domain is None and other.domain is None: + return self + elif self.domain is None: + return other + elif other.domain is None: + return self + elif (self.domain.device_type == other.domain.device_type and + self.domain.device_id == other.domain.device_id): + return self + else: + raise Exception("all expressions must have a singular device") + + def __hash__(self): + if self.domain is None: + return id(self) + else: + return hash((self.domain.device_type, self.domain.device_id)) + + def __eq__(self, other): + if self.domain is None and other.domain is None: + return id(self) == id(other) + else: + return self.domain == other.domain + + +def bottom(): + """Create an empty domain. This would usually happen when we enter a new + scope, i.e. Function. + """ + return DeviceDomain(None) + + +def device_type(ctx): + """Create a domain with the given device context. + Parameters + ---------- + ctx : tvm.runtime.TVMContext + The device context used to construct a domain. + Returns + ------- + ret : DeviceDomain + The constructed domain. + """ + return DeviceDomain(ctx) + + +class ContextAnalysis(ExprVisitor): + """Compute on which device each sub-expression will execute. A union find + algorithm is used to assign and merge the context domains. + Parameters + ---------- + fallback_device : tvm.rutnime.TVMContext + The default device that could be attached to an expression. + """ + def __init__(self, fallback_device): + super().__init__() + self.expr_to_device = defaultdict(bottom) + self.device_uf = {} + self.fallback_device = fallback_device + + def lookup(self, device): + """Find the root domain of a given device domain. + Parameters + ---------- + device : DeviceDomain + The domain that is used to query the root domain. + Returns + ------- + ret : DeviceDomain + The root domain. + """ + while device in self.device_uf: + device = self.device_uf[device] + return device + + def unify(self, lhs, rhs): + """Unify the device context of two domains. + Parameters + ---------- + lhs : DeviceDomain + The lhs domain to unify. + rhs : DeviceDomain + The rhs domain to unify. + Returns + ------- + ret : DeviceDomain + The unified domain. + """ + lhs = self.lookup(lhs) + rhs = self.lookup(rhs) + unified_device = lhs.join(rhs) + if not lhs == unified_device: + self.device_uf[lhs] = unified_device + if not rhs == unified_device: + self.device_uf[rhs] = unified_device + return unified_device + + def unify_expr(self, lhs, rhs): + """Compute the device type of both expressions and unify them. + Parameters + ---------- + lhs : tvm.relay.Expr + The lhs expression to unify. + rhs : tvm.relay.Expr + The rhs expression to unify. + Returns + ------- + ret : DeviceDomain + The unified domain. + """ + return self.unify(self.device_for(lhs), self.device_for(rhs)) + + def device_for(self, expr): + """Find the domain that contains the given expr. + Parameters + ---------- + expr : tvm.relay.Expr + The expression used to lookup a domain. + Returns + ------- + ret : DeviceDomain + The domain that contains the expression. + """ + return self.lookup(self.expr_to_device[expr]) + + def device_copy(self, inps, outputs, src_dev_type, dst_dev_type): + """Unify the device context for device copy node. Device copy node is + the only node that carries information in the input program. The device + attribute of other nodes are propagated from it. + Parameters + ---------- + inps : List[tvm.relay.Expr] + The input expression to the device copy node. The device type of + the input should be the same as the source device type of the + copy node. + outputs : List[tvm.relay.Expr] + The output expression of the device copy node. The device type of + the output should be the same as the destination device type of the + copy node. + src_dev_type : int + The source device type of the copy node. + dst_dev_type : int + The destination device type of the copy node. + """ + src_dev_type = device_type(TVMContext(src_dev_type, 0)) + for inp in inps: + self.unify(self.device_for(inp), src_dev_type) + + dst_dev_type = device_type(TVMContext(dst_dev_type, 0)) + for output in outputs: + self.unify(self.device_for(output), dst_dev_type) + + def unify_call(self, call_op, inputs, outputs, device=None): + """Unify the domain of inputs and outputs of a relay Call. + Parameters + ---------- + op : tvm.relay.Expr + The op of a call node. + inputs : List[tvm.relay.Expr] + The inputs of the call. + outputs : List[tvm.relay.Expr] + The outputs of the call. + Returns + ------- + The unified domain. + Note + ---- + For most call nodes, the op, inputs, and outputs should all be in the + same domain, i.e. have the same context. However, device_copy call node + needs to be handled different as it copies data from one device to + another. + """ + device = device if device else bottom() + for arg in inputs: + device = self.unify(device, self.device_for(arg)) + + device = self.unify(device, self.device_for(call_op)) + + for out in outputs: + device = self.unify(device, self.device_for(out)) + + return device + + def visit_call(self, call): + if is_device_copy(call): + inps = [call.args[0]] + outs = [call] + if isinstance(call.op, Function): + # device_copy is fused, propagate device to the fused function + inps.append(call.op.params[0]) + outs.append(call.op) + body = call.op.body + assert isinstance(body, _expr.Call) and is_device_copy(body) + outs.append(call.op.body) + src_dev_type = call.op.body.attrs.src_dev_type + dst_dev_type = call.op.body.attrs.dst_dev_type + else: + src_dev_type = call.attrs.src_dev_type + dst_dev_type = call.attrs.dst_dev_type + + # Device copy op only has one input which is now annotated with the + # same device to the source device type of the device copy op. + # The call itself has the same device type to the destination. + self.device_copy(inps, outs, src_dev_type, dst_dev_type) + super().visit_call(call) + elif call.op == op.op.get("memory.alloc_storage"): + call_dev = device_type(TVMContext(call.attrs.device_type, + call.attrs.device_id)) + self.unify(self.device_for(call), call_dev) + # The arguments should be one the same device as the call. + self.visit(call.args[0]) + size = call.args[0] + self.visit(call.args[1]) + alignment = call.args[1] + self.unify(self.device_for(size), call_dev) + self.unify(self.device_for(alignment), call_dev) + elif call.op == op.op.get("memory.alloc_tensor"): + storage = call.args[0] + shape = call.args[1] + self.visit(call.args[1]) + self.unify(self.device_for(storage), self.device_for(call)) + self.unify(self.device_for(shape), self.device_for(call)) + elif call.op == op.op.get("vm.shape_func"): + shape_func_domain = device_type(cpu(0)) + # No need to union the op of a shape_func as shape_func doesn't + # invoke the op itself. It should be handled by invoke_tvm_op. + # Therefore, we skip call.args[0] here. + self.unify_call(call, call.args[1].fields, + call.args[2].fields, shape_func_domain) + for arg in call.args[1]: + self.visit(arg) + for arg in call.args[2]: + self.visit(arg) + elif call.op == op.op.get("vm.invoke_tvm_op"): + if isinstance(call.args[0].body, _expr.Call) and \ + call.args[0].body.op == op.op.get("device_copy"): + input_tensor = call.args[1] + output_tensor = call.args[2] + self.device_copy(input_tensor, output_tensor, + call.attrs.src_dev_type, + call.attrs.dst_dev_type) + else: + device = self.unify_call(call.args[0], call.args[1].fields, + call.args[2].fields) + self.unify(self.device_for(call), device) + super().visit_call(call) + elif isinstance(call.op, Function): + device = self.device_for(call) + for arg in call.args: + device = self.unify(device, self.device_for(arg)) + self.visit(arg) + + for param in call.op.params: + self.visit(param) + device = self.unify(device, self.device_for(param)) + + self.unify(device, self.device_for(call.op)) + self.unify(device, self.device_for(call.op.body)) + self.visit(call.op) + else: + self.unify_call(call, call.args, [call]) + super().visit_call(call) + + def visit_let(self, let): + while isinstance(let, _expr.Let): + self.unify(self.device_for(let.var), self.device_for(let.value)) + self.unify_expr(let, let.body) + self.visit(let.var) + self.visit(let.value) + let = let.body + + def visit_function(self, f): + self.unify(self.device_for(f), self.device_for(f.body)) + super().visit_function(f) + + def visit_tuple(self, tup): + # We only support tuple with the same of device. + device = self.device_for(tup[0]) + for i in range(1, len(tup)): + device = self.unify(device, self.device_for(tup[i])) + self.unify(device, self.device_for(tup)) + super().visit_tuple(tup) + + def visit_tuple_getitem(self, t): + value = t.tuple_value + if isinstance(t.tuple_value, _expr.Tuple): + value = t.tuple_value[t.index] + self.unify(self.device_for(t), self.device_for(value)) + super().visit_tuple_getitem(t) + + def visit_var(self, var): + self.device_for(var) + + def visit_constant(self, const): + self.device_for(const) + + def results(self): + """Return the analysis result. + Returns + ------- + ret : Dict[tvm.relay.Expr, DeviceDomain] + The dictionary mapping each expression to a device context. + """ + results = {} + for exp in self.expr_to_device: + device = self.lookup(self.expr_to_device[exp]) + if device.domain is None: + results[exp] = self.fallback_device + else: + results[exp] = device.domain + + return results + + +def mk_analysis_annotator(results): + """Pretty print the annotated relay program with device info""" + def _annotator(exp): + if exp in results: + return f"<{results[exp]}>" + else: + return "" + + return _annotator + + +def context_analysis(expr, fallback_device): + """Perform device context analysis on a given relay program. This requires + that the program has already been annotated and rewritten by replacing on + device annotations with device copy nodes. + Parameters + ---------- + expr : tvm.relay.Expr + The expression for analysis + fallback_device : tvm.runtime.TVMContext + The default device context + Returns + ------- + ret : Dict[tvm.relay.Expr, [int]] + The mapping of each expression to the device context that is + represented in a list form as TVMContext is not a runtime object. + """ + ca = ContextAnalysis(fallback_device) + ca.visit(expr) + ret = defaultdict(list) + for key, val in ca.results().items(): + ret[key] = [val.device_type, val.device_id] + return ret + + +register_func("relay.analysis.ContextAnalysis", context_analysis) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 73b0d22804bd..656652c23004 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -27,7 +27,6 @@ import tvm.runtime.vm as vm_rt from tvm import autotvm from tvm.relay import expr as _expr -from tvm.relay.ty import is_dynamic from tvm.relay.backend.interpreter import Executor from . import _vm @@ -261,12 +260,6 @@ def _make_executor(self, expr=None): def _vm_wrapper(*args, **kwargs): args = self._convert_args(main, args, kwargs) - ret_type = self.mod["main"].checked_type.ret_type - if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str( - self.target): - raise ValueError( - "Virtual Machine only supports dynamic graphs on CPU, got output type", - ret_type, "on target", self.target) return self.vm.run(*args) return _vm_wrapper diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index eccc2c3c5f15..46434461d429 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -87,6 +87,7 @@ register_broadcast_schedule("fast_exp") register_broadcast_schedule("fast_tanh") register_broadcast_schedule("fast_erf") +register_broadcast_schedule("device_copy") # zeros @@ -241,3 +242,4 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("fast_erf", False, elemwise_shape_func) register_shape_func("floor", False, elemwise_shape_func) register_shape_func("log", False, elemwise_shape_func) +register_shape_func("device_copy", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 937c36e60919..70f56a489314 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -217,7 +217,7 @@ def concatenate_shape_func(attrs, inputs, _): return [_concatenate_shape_func(inputs, convert(axis))] @script -def _reshape_shape_func_input_shape(data_shape, newshape, ndim): +def _reshape_shape_func_input_shape(data_shape, newshape, ndim, reverse=True): out = output_tensor((ndim,), "int64") src_idx = 0 dst_idx = 0 @@ -677,3 +677,15 @@ def split_shape_func(attrs, inputs, _): convert(i), convert(indices_or_sections), convert(axis)) for i in range(num_out)] + +@_reg.register_shape_func("contrib_reverse_reshape", False) +def contrib_reverse_reshape_shape_func(attrs, inputs, out_ndims): + newshape = get_const_tuple(attrs.newshape) + print(inputs[0]) + data_shape = reversed(inputs[0]) + newshape = reversed(newshape) + return [_reshape_shape_func_input_shape(data_shape, + convert(newshape), + out_ndims[0])] + + diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index ae7db3384214..106fb742d749 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -19,6 +19,11 @@ A pass for manifesting explicit memory allocations. """ import numpy as np +import logging + +from tvm.ir.transform import PassContext +from tvm import nd, container, tir +from ..function import Function from ..expr_functor import ExprVisitor, ExprMutator from ..scope_builder import ScopeBuilder from . import transform @@ -29,16 +34,32 @@ from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type from ...import cpu from ..op.memory import alloc_storage +from ..analysis.context_analysis import ContextAnalysis, mk_analysis_annotator +from ..._ffi.runtime_ctypes import TVMContext def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): offset = expr.const(0, dtype="int64") return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape) + def is_primitive(call): return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \ hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1 +def is_device_copy(func): + """ + Check if the current relay expression is shape_of call. We can simply check + the body of it if it is a function becase the shape_of op is opaque. + """ + if isinstance(func, Function): + body = func.body + return isinstance(body, expr.Call) and body.op == op.get("device_copy") + if isinstance(func, expr.Call): + return body.op == op.get("device_copy") + return False + + class CheckReshapeOnly(ExprVisitor): """A pass to check if the fused op contains only reshape ops.""" def __init__(self): @@ -66,7 +87,7 @@ def is_reshape_only(func): class ManifestAllocPass(ExprMutator): """A pass for explicitly manifesting all memory allocations in Relay.""" - def __init__(self, target_host): + def __init__(self, target_host, context_analysis): self.invoke_tvm = op.vm.invoke_tvm_op self.shape_func = op.vm.shape_func self.shape_of = op.vm.shape_of @@ -75,8 +96,18 @@ def __init__(self, target_host): self.target_host = target_host self.default_context = cpu(0) self.compute_dtype = "int64" + self.context_analysis = context_analysis super().__init__() + def get_context(self, expr): + assert expr in self.context_analysis, expr.astext(False) + return self.context_analysis[expr] + + def device_copy(self, scope, inp, src_ctx, dst_ctx, idx): + copy = self.visit(op.tensor.device_copy(inp, src_ctx, dst_ctx)) + copy_out = scope.let("copy_out_{0}".format(idx), copy) + return copy_out + def current_scope(self): return self.scopes[-1] @@ -116,7 +147,7 @@ def compute_storage(self, tensor_type): size *= (dtype.bits * dtype.lanes + 7) // 8 return expr.const(size, dtype=self.compute_dtype) - def make_static_allocation(self, scope, tensor_type, i): + def make_static_allocation(self, scope, tensor_type, ctx, name_hint): """Allocate a tensor with a statically known shape.""" shape = [int(sh) for sh in tensor_type.shape] if len(shape) == 0: @@ -126,11 +157,13 @@ def make_static_allocation(self, scope, tensor_type, i): size = self.compute_storage(tensor_type) alignment = self.compute_alignment(tensor_type.dtype) dtype = tensor_type.dtype - sto = scope.let("storage_{0}".format(i), alloc_storage( - size, alignment, self.default_context, dtype)) + sto = scope.let("storage_{0}".format(name_hint), alloc_storage(size, + alignment, + ctx, + dtype)) # TODO(@jroesch): There is a bug with typing based on the constant shape. tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape) - return scope.let("tensor_{0}".format(i), tensor) + return scope.let("tensor_{0}".format(name_hint), tensor) def visit_let(self, let): scope = ScopeBuilder() @@ -156,13 +189,17 @@ def emit_shape_func(self, scope, func, new_args): is_inputs = [] input_pos = 0 + cpu_ctx = nd.cpu(0) for i, (arg, state) in enumerate(zip(new_args, input_states)): state = int(state) # Pass Shapes if state == 2: for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)): + ctx = self.get_context(subexp) + if ctx.device_type != cpu_ctx.device_type: + subexp = self.device_copy(scope, subexp, ctx, cpu_ctx, j) let_in_arg = scope.let("in_arg_{0}".format(input_pos + j), subexp) - sh_of = self.visit(self.shape_of(let_in_arg)) + sh_of = self.visit(self.shape_of(subexp)) shape_func_ins.append( scope.let("in_shape_{0}".format(input_pos + j), sh_of)) input_pos += 1 @@ -170,6 +207,9 @@ def emit_shape_func(self, scope, func, new_args): # Pass Inputs elif state == 1: new_arg = self.visit(arg) + ctx = self.get_context(arg) + if ctx.device_type != cpu_ctx.device_type: + new_arg = self.device_copy(scope, new_arg, ctx, cpu_ctx, i) shape_func_ins.append( scope.let("in_shape_{0}".format(input_pos), new_arg)) input_pos += 1 @@ -181,7 +221,7 @@ def emit_shape_func(self, scope, func, new_args): out_shapes = [] for i, out in enumerate(cfunc.outputs): tt = ty.TensorType(out.shape, out.dtype) - alloc = self.make_static_allocation(scope, tt, i) + alloc = self.make_static_allocation(scope, tt, cpu_ctx, i) alloc = scope.let("shape_func_out_{0}".format(i), alloc) out_shapes.append(alloc) @@ -198,16 +238,22 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): out_shapes = self.emit_shape_func(scope, func, new_args) storages = [] + cpu_ctx = nd.cpu(0) + func_ctx = self.get_context(func) + copy_out_shapes = [] for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)): - size = self.compute_storage_in_relay( - out_shape, out_type.dtype) + size = self.compute_storage_in_relay(out_shape, out_type.dtype) alignment = self.compute_alignment(out_type.dtype) + if func_ctx.device_type != cpu_ctx.device_type: + size = self.device_copy(scope, size, cpu_ctx, func_ctx, i) + out_shape = self.device_copy(scope, out_shape, cpu_ctx, func_ctx, i) + copy_out_shapes.append(out_shape) sto = scope.let("storage_{i}".format(i=i), alloc_storage( - size, alignment, self.default_context, out_type.dtype)) + size, alignment, func_ctx, out_type.dtype)) storages.append(sto) outs = [] - sh_ty_storage = zip(out_shapes, out_types, storages) + sh_ty_storage = zip(copy_out_shapes, out_types, storages) for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage): alloc = alloc_tensor( storage, @@ -226,11 +272,19 @@ def emit_reshape_tensor(self, scope, func, new_args, ret_type): if self.is_dynamic(ret_type): out_shapes = self.emit_shape_func(scope, func, new_args) shape_expr = out_shapes[0] + inp = new_args[0] + inp_ctx = self.get_context(func) + cpu_ctx = nd.cpu(0) + if inp_ctx.device_type != cpu_ctx.device_type: + shape_expr = self.device_copy(scope, shape_expr, cpu_ctx, + inp_ctx, 0) + ret = self.reshape_tensor(inp, shape_expr, ret_type.shape) + return ret else: # constant output shape shape = [int(dim) for dim in ret_type.shape] shape_expr = expr.const(shape, dtype=self.compute_dtype) - return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape) + return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape) def is_dynamic(self, ret_type): is_dynamic = ty.is_dynamic(ret_type) @@ -253,6 +307,15 @@ def visit_call(self, call): # Handle fused op that only contains reshape op return self.emit_reshape_tensor(scope, call.op, new_args, ret_type) + if is_device_copy(call.op): + # Handle device copy op + if isinstance(call.op, Function): + attr = call.op.body.attrs + else: + attr = call.attr + return op.tensor.device_copy(new_args[0], + TVMContext(attr.src_dev_type, 0), + TVMContext(attr.dst_dev_type, 0)) if self.is_dynamic(ret_type): # Handle dynamic case. return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type) @@ -260,7 +323,9 @@ def visit_call(self, call): # Handle static case. outs = [] for i, out_ty in enumerate(out_types): - out = self.make_static_allocation(scope, out_ty, i) + ctx = self.get_context(call) + assert isinstance(ctx, TVMContext) + out = self.make_static_allocation(scope, out_ty, ctx, i) outs.append(out) output = expr.Tuple(outs) @@ -273,14 +338,35 @@ def visit_call(self, call): @transform.function_pass(opt_level=0) class ManifestAlloc: """The explicit pass wrapper around ManifestAlloc.""" - def __init__(self, target_host): + def __init__(self, target_host, targets): self.target_host = target_host + self.targets = targets def transform_function(self, func, mod, _): # TODO(@jroesch): Is there a way to do one shot initialization? # can we have def pass_init? mod.import_from_std("core.rly") - ea = ManifestAllocPass(self.target_host) + + assert isinstance(self.targets, (dict, container.Map)) + if len(self.targets) > 1: + pass_ctx = PassContext.current() + if "relay.fallback_device_type" in pass_ctx.config: + fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"]) + else: + fallback_ctx = cpu(0) + ca = ContextAnalysis(TVMContext(fallback_ctx.device_type, 0)) + else: + dev, _ = self.targets.items()[0] + ca = ContextAnalysis(nd.context(dev.value)) + + # We use logger here to help debug. + logging.debug("-----BEFORE ANALYSIS-----") + logging.debug(func.astext(False)) + ca.visit(func) + logging.debug("-----AFTER ANALYSIS-----") + logging.debug(func.astext(show_meta_data=False, + annotate=mk_analysis_annotator(ca.results()))) + ea = ManifestAllocPass(self.target_host, ca.results()) func = ea.visit(func) return func diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 8f21af9292a9..d93ee03bdfd8 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -84,6 +84,10 @@ def grow( self.alignment = alignment if self.ctx: + print(self.ctx.device_type) + print(ctx.device_type) + print(self.ctx.device_id) + print(ctx.device_id) assert (self.ctx.device_type == ctx.device_type and self.ctx.device_id == ctx.device_id), "must have matching context" else: @@ -282,6 +286,10 @@ def process_alloc_storage(self, dynamic_regions, lhs, call): dynamic_regions.append(lhs) region = self.current_region(dtype) + if region.ctx and (region.ctx.device_type != ctx.device_type or \ + region.ctx.device_id != ctx.device_id): + return lhs, call + region.grow(lhs, size, alignment, ctx, dtype) return lhs, region.var diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index f88f43d838d5..4bfc5fd0a8a5 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -307,8 +307,14 @@ def __init__(self, exe, ctx, memory_cfg=None): def _setup_ctx(self, ctx, memory_cfg): """Init context and allocators.""" - if isinstance(ctx, tvm.runtime.TVMContext): - ctx = [ctx] + ctxs = ctx + if not isinstance(ctx, (list, tuple)): + assert isinstance(ctx, tvm.runtime.TVMContext) + ctxs = [ctx] + # CPU is required for executing shape functions + if ctx.device_type != tvm.cpu(0).device_type: + ctxs.append(tvm.cpu()) + default_alloc_type = VirtualMachine.POOLED_ALLOCATOR if memory_cfg is None: memory_cfg = {} @@ -321,7 +327,7 @@ def _setup_ctx(self, ctx, memory_cfg): raise TypeError("memory_cfg is expected be string or dictionary, " + "but received {}".format(type(memory_cfg))) init_args = [] - for context in ctx: + for context in ctxs: init_args.append(context.device_type) init_args.append(context.device_id) alloc_type = memory_cfg[context] if context in memory_cfg else default_alloc_type diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 33854f783d45..b89911d60227 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -51,15 +52,33 @@ namespace tvm { namespace relay { +using ExprDeviceMap = std::unordered_map; + namespace transform { Pass LambdaLift(); Pass InlinePrimitives(); -Pass ManifestAlloc(Target target_host) { +Pass ManifestAlloc(Target target_host, vm::TargetsMap targets) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); CHECK(f != nullptr) << "unable to load allocation manifestation pass"; - return (*f)(target_host); + return (*f)(target_host, targets); +} + +ExprDeviceMap ContextAnalysis(Expr expr, TVMContext default_device) { + auto f = tvm::runtime::Registry::Get("relay.analysis.ContextAnalysis"); + CHECK(f != nullptr) << "could not load context analysis pass"; + Map > m = (*f)(expr, default_device); + ExprDeviceMap ret; + for (const auto& it : m) { + TVMContext ctx; + Array ints = it.second; + CHECK_EQ(ints.size(), 2U); + ctx.device_type = static_cast(ints[0]->value); + ctx.device_id = static_cast(ints[1]->value); + ret[it.first] = ctx; + } + return ret; } Pass MemoryPlan() { @@ -228,6 +247,27 @@ std::vector ToAllocTensorShape(NDArray shape) { return raw_shape; } +/*! + * \brief Create a default type. + * \param device_type The device type index. + * \return the default target for the device. + */ +Target CreateDefaultTarget(int device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") return Target::Create("llvm"); + if (name == "gpu") return Target::Create("cuda"); + return Target::Create(name); +} + +int GetFallbackDevice() { + transform::PassContext pass_ctx = PassContext::Current(); + Optional opt_fallback_dev = + pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); + auto fallback_dev = opt_fallback_dev.value(); + CHECK_GT(fallback_dev->value, 0U); + return fallback_dev->value; +} + class VMFunctionCompiler : ExprFunctor { public: VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) @@ -235,10 +275,29 @@ class VMFunctionCompiler : ExprFunctor { registers_num_(0), engine_(CompileEngine::Global()), context_(context), - targets_(targets), - target_host_(target_host) {} + target_host_(target_host) { + for (const auto& it : targets) { + targets_[it.first->value] = it.second; + } + } VMFunction Compile(const GlobalVar& var, const Function& func) { + // Collect the annotated device information. + // This indicates which device each Relay expr should be executed on. + TVMContext default_device; + if (targets_.size() > 1) { + int fallback_dev = GetFallbackDevice(); + default_device.device_type = static_cast(fallback_dev); + default_device.device_id = 0; + expr_device_map_ = transform::ContextAnalysis(func, default_device); + } else { + default_device.device_type = static_cast((targets_.begin())->first); + if (default_device.device_type != kDLCPU) { + default_device.device_id = 0; + expr_device_map_ = transform::ContextAnalysis(func, default_device); + } + } + size_t i = 0; // We then assign register num to the free variables for (auto param : func->params) { @@ -263,7 +322,19 @@ class VMFunctionCompiler : ExprFunctor { this->VisitExpr(func->body); } instructions_.push_back(Instruction::Ret(last_register_)); - return VMFunction(var->name_hint, params_, instructions_, registers_num_); + + std::vector params_device_type; + for (const auto& it : func->params) { + if (!expr_device_map_.empty()) { + CHECK_GT(expr_device_map_.count(it), 0U); + params_device_type.push_back(expr_device_map_[it].device_type); + } else { + CHECK_EQ(targets_.size(), 1U); + params_device_type.push_back((targets_.begin())->first); + } + } + + return VMFunction(var->name_hint, params_, instructions_, registers_num_, params_device_type); } protected: @@ -287,6 +358,7 @@ class VMFunctionCompiler : ExprFunctor { case Opcode::ReshapeTensor: case Opcode::Move: case Opcode::InvokeClosure: + case Opcode::DeviceCopy: last_register_ = instr.dst; break; case Opcode::InvokePacked: @@ -310,6 +382,13 @@ class VMFunctionCompiler : ExprFunctor { } } size_t konst_idx = context_->constants.size(); + if (expr_device_map_.empty()) { + context_->const_device_type.push_back(targets_.begin()->first); + } else { + auto con = GetRef(const_node); + CHECK_GT(expr_device_map_.count(con), 0U); + context_->const_device_type.push_back(expr_device_map_[con].device_type); + } context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); } @@ -477,13 +556,23 @@ class VMFunctionCompiler : ExprFunctor { target = tvm::target::ext_dev(); } else { // Next generate the invoke instruction. - if (targets_.size() == 1) { + if (expr_device_map_.empty()) { // homogeneous execution. + CHECK_EQ(targets_.size(), 1U); const auto& it = targets_.begin(); target = (*it).second; } else { - // heterogeneous execution. - LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation"; + if (expr_device_map_.count(func) == 0 || + targets_.count(expr_device_map_[func].device_type) == 0) { + int fallback_dev = GetFallbackDevice(); + auto dev_name = runtime::DeviceName(fallback_dev); + if (expr_device_map_.count(func) == 0) { + LOG(WARNING) << "The function is not annotated. Fallback to " << dev_name; + } + target = CreateDefaultTarget(fallback_dev); + } else { + target = targets_[expr_device_map_[func].device_type]; + } } } @@ -561,7 +650,8 @@ class VMFunctionCompiler : ExprFunctor { } }) .Match("memory.alloc_storage", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + [this, call_node](const Array& args, const Attrs& attrs, + const Array& type_arg) { CHECK_EQ(args.size(), 2); // Compute the size of the allocation. this->VisitExpr(args[0]); @@ -577,10 +667,23 @@ class VMFunctionCompiler : ExprFunctor { // Get the dtype hint from the attributes. auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + CHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs"; auto dtype = alloc_attrs->dtype; - Emit(Instruction::AllocStorage(size_register, alignment, dtype, NewRegister())); + Index device_type; + // There is bug if all expression are annotated with the device that + // other than the first one in the target list. + if (expr_device_map_.empty()) { + auto& kv = *(targets_.begin()); + device_type = kv.first; + } else { + CHECK_GT(expr_device_map_.count(GetRef(call_node)), 0U) + << " The alloc_storage node is not annotated"; + device_type = expr_device_map_[GetRef(call_node)].device_type; + } + + Emit(Instruction::AllocStorage(size_register, alignment, dtype, device_type, + NewRegister())); }) .Match("vm.shape_func", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -611,6 +714,19 @@ class VMFunctionCompiler : ExprFunctor { auto shape_reg = last_register_; Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister())); }) + .Match("device_copy", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 1U); + this->VisitExpr(args[0]); + auto src_reg = last_register_; + + auto device_copy_attrs = attrs.as(); + CHECK(device_copy_attrs != nullptr) << "Must be the device copy attrs"; + Index src_device_type = device_copy_attrs->src_dev_type; + Index dst_device_type = device_copy_attrs->dst_dev_type; + Emit(Instruction::DeviceCopy(src_reg, src_device_type, dst_device_type, + NewRegister())); + }) .Match("memory.kill", [](const Array& args, const Attrs& attrs, const Array& type_arg) { LOG(FATAL) << "memory.kill is not yet supported"; @@ -769,9 +885,11 @@ class VMFunctionCompiler : ExprFunctor { /*! \brief Global shared meta data */ VMCompilerContext* context_; /*! \brief Target devices. */ - TargetsMap targets_; + std::unordered_map targets_; /*! \brief Host target. */ Target target_host_; + /*! \brief Map from Relay expr to device type. */ + ExprDeviceMap expr_device_map_; }; PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { @@ -820,7 +938,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { } void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { - CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; if (params_.size()) { BaseFunc base_func = mod->Lookup("main"); CHECK(base_func->IsInstance()) @@ -871,6 +988,10 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe exec_->constants.push_back(data); } + for (auto i : context_.const_device_type) { + exec_->const_device_type.push_back(i); + } + // update global function map for (auto gv : context_.global_map) { exec_->global_map.insert({gv.first->name_hint, gv.second}); @@ -883,10 +1004,10 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe } } -transform::Sequential MemoryOpt(tvm::Target host_target) { +transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { Array pass_seqs; // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(host_target)); + pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -895,25 +1016,25 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { pass_seqs.push_back(transform::FuseOps()); // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(host_target)); + pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); - // Fuse the shape functions. - pass_seqs.push_back(transform::FuseOps()); + // // Fuse the shape functions. + // pass_seqs.push_back(transform::FuseOps()); - // Perform memory planning in order to coalesce/reduce allocations. - pass_seqs.push_back(transform::MemoryPlan()); + // // Perform memory planning in order to coalesce/reduce allocations. + // pass_seqs.push_back(transform::MemoryPlan()); - // Compute away constant computation introduced by coalescing allocations. - pass_seqs.push_back(transform::FoldConstant()); + // // Compute away constant computation introduced by coalescing allocations. + // pass_seqs.push_back(transform::FoldConstant()); - // Fuse the shape functions. - pass_seqs.push_back(transform::FuseOps()); + // // Fuse the shape functions. + // pass_seqs.push_back(transform::FuseOps()); - // Create allocations for math introduced by dynamic region math. - pass_seqs.push_back(transform::ManifestAlloc(host_target)); + // // Create allocations for math introduced by dynamic region math. + // pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); - // Compute away possibly introduced constant computation. - pass_seqs.push_back(transform::FoldConstant()); + // // Compute away possibly introduced constant computation. + // pass_seqs.push_back(transform::FoldConstant()); // Lift constants to the top-level of the block to simplify VM code generation. // TODO(@icemelon9, @jroesch): Remove this pass for now because some @@ -977,6 +1098,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::FastMath()); pass_seqs.push_back(transform::FoldConstant()); + if (targets_.size() > 1) { + // Handle heterogeneous compilation. + int fallback_dev = GetFallbackDevice(); + pass_seqs.push_back(transform::RewriteAnnotatedOps(fallback_dev)); + } + pass_seqs.push_back(transform::FuseOps()); pass_seqs.push_back(transform::ToANormalForm()); pass_seqs.push_back(transform::LambdaLift()); @@ -989,11 +1116,10 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // external codegen. pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(target_host)); + pass_seqs.push_back(MemoryOpt(target_host, targets)); transform::Sequential seq(pass_seqs); transform::PassContext pass_ctx = PassContext::Current(); - // TODO(wweic): Support heterogenous execution tvm::With ctx(pass_ctx); if (targets.size() == 1) { const auto& it = targets.begin(); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index b4b86d3d6d8e..78cff7ba8d4f 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -76,6 +76,8 @@ struct VMCompilerContext { GlobalMap global_map; // List of constants std::vector constants; + // Device type for constants + std::vector const_device_type; // List of cached functions std::vector cached_funcs; // The functions that have been lowered. diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 923965f98192..3a58607e6dd8 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include "../transforms/infer_layout_util.h" #include "type_relations.h" @@ -60,7 +61,12 @@ on different devices. .add_type_rel("Identity", IdentityRel) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 7273c28a0e93..bdc613d85cdb 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -79,6 +79,7 @@ class ConstantFolder : public ExprMutator { public: explicit ConstantFolder(IRModule module) : module_(module), + device_copy_op_(Op::Get("device_copy")), shape_of_op_(Op::Get("shape_of")), vm_shape_of_op_(Op::Get("vm.shape_of")), invoke_tvm_op_(Op::Get("vm.invoke_tvm_op")), @@ -134,7 +135,7 @@ class ConstantFolder : public ExprMutator { // We should think about potentially constant evaluation over these ops too. if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || - call->op == alloc_storage_op_) { + call->op == alloc_storage_op_ || call->op == device_copy_op_) { return GetRef(call); } @@ -168,6 +169,7 @@ class ConstantFolder : public ExprMutator { IRModule module_; // Cache the following ops for equivalence checking in this pass. + const Op& device_copy_op_; const Op& shape_of_op_; const Op& vm_shape_of_op_; const Op& invoke_tvm_op_; diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index edfd3acfb3e2..754858fb5d0e 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -123,6 +123,11 @@ Instruction::Instruction(const Instruction& instr) { this->reshape_tensor.tensor = instr.reshape_tensor.tensor; this->reshape_tensor.newshape = instr.reshape_tensor.newshape; return; + case Opcode::DeviceCopy: + this->src = instr.src; + this->src_device_type = instr.src_device_type; + this->dst_device_type = instr.dst_device_type; + return; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -220,6 +225,15 @@ Instruction& Instruction::operator=(const Instruction& instr) { case Opcode::ShapeOf: this->shape_of.tensor = instr.shape_of.tensor; return *this; + case Opcode::ReshapeTensor: + this->reshape_tensor.tensor = instr.reshape_tensor.tensor; + this->reshape_tensor.newshape = instr.reshape_tensor.newshape; + return *this; + case Opcode::DeviceCopy: + this->src = instr.src; + this->src_device_type = instr.src_device_type; + this->dst_device_type = instr.dst_device_type; + return *this; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -241,6 +255,7 @@ Instruction::~Instruction() { case Opcode::AllocStorage: case Opcode::ShapeOf: case Opcode::ReshapeTensor: + case Opcode::DeviceCopy: case Opcode::Fatal: return; case Opcode::AllocTensor: @@ -324,13 +339,14 @@ Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName } Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - RegName dst) { + Index device_type, RegName dst) { Instruction instr; instr.op = Opcode::AllocStorage; instr.dst = dst; instr.alloc_storage.allocation_size = size; instr.alloc_storage.alignment = alignment; instr.alloc_storage.dtype_hint = dtype_hint; + instr.alloc_storage.device_type = device_type; return instr; } @@ -351,6 +367,17 @@ Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName return instr; } +Instruction Instruction::DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, + RegName dst) { + Instruction instr; + instr.op = Opcode::DeviceCopy; + instr.dst = dst; + instr.src = src; + instr.src_device_type = src_device_type; + instr.dst_device_type = dst_device_type; + return instr; +} + Instruction Instruction::AllocADT(Index tag, Index num_fields, const std::vector& datatype_fields, RegName dst) { Instruction instr; @@ -582,7 +609,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { case Opcode::AllocStorage: { os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " " << instr.alloc_storage.alignment << " " - << DLDataType2String(instr.alloc_storage.dtype_hint); + << DLDataType2String(instr.alloc_storage.dtype_hint) << " " + << instr.alloc_storage.device_type; break; } case Opcode::ShapeOf: { @@ -594,6 +622,11 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { << instr.reshape_tensor.newshape; break; } + case Opcode::DeviceCopy: { + os << "device_copy $" << instr.dst << " $" << instr.src << " " << instr.src_device_type << " " + << instr.dst_device_type; + break; + } default: LOG(FATAL) << "should never hit this case" << static_cast(instr.op); break; diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 998762187dfc..ef2091746795 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -351,6 +351,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(dtype.code); fields.push_back(dtype.bits); fields.push_back(dtype.lanes); + fields.push_back(instr.alloc_storage.device_type); fields.push_back(instr.dst); break; } @@ -428,6 +429,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst}); break; } + case Opcode::DeviceCopy: { + // Number of fields = 4 + fields.assign({instr.src, instr.src_device_type, instr.dst_device_type, instr.dst}); + break; + } default: LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); break; @@ -631,9 +637,10 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { dtype.bits = instr.fields[3]; dtype.lanes = instr.fields[4]; - RegName dst = instr.fields[5]; + Index device_type = instr.fields[5]; + RegName dst = instr.fields[6]; - return Instruction::AllocStorage(allocation_size, alignment, dtype, dst); + return Instruction::AllocStorage(allocation_size, alignment, dtype, device_type, dst); } case Opcode::If: { // Number of fields = 4 @@ -704,6 +711,12 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { DCHECK_EQ(instr.fields.size(), 3U); return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]); } + case Opcode::DeviceCopy: { + // Number of fields = 4 + DCHECK_EQ(instr.fields.size(), 4U); + return Instruction::DeviceCopy(instr.fields[0], instr.fields[1], instr.fields[2], + instr.fields[3]); + } default: LOG(FATAL) << "Invalid opcode" << instr.opcode; return Instruction(); diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 7273b565cd69..32aedc527e24 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -105,9 +105,12 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { CHECK(exec_); - auto ctx = this->GetParamsContext(); - // warmup - VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); + CHECK(!ctxs_.empty()) << "Context has not been initialized yet."; + // TODO(@zhiics) Need to record the device type of each packed func so that + // we can correctly sync. + Index fallback_device_type = static_cast(ctxs_[0].device_type); + auto ctx = this->GetContext(fallback_device_type); + TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); auto op_begin = std::chrono::high_resolution_clock::now(); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 9af520228fee..b784f3a12737 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -146,12 +146,15 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, auto func_index = gvit->second; const auto& vm_func = exec_->functions[func_index]; const auto& param_names = vm_func.params; - // TODO(icemelon9): For heterogeneous execution, get input device information - TVMContext ctx = ctxs_[0]; CHECK_EQ(args.size() - 1, param_names.size()) << "The number of provided parameters doesn't match the number of arguments"; std::vector func_args(param_names.size()); for (int i = 1; i < args.size(); ++i) { + TVMContext ctx; + int device_type = vm_func.params_device_type[i-1]; + ctx.device_type = DLDeviceType(device_type); + //TODO(zhiics) How to decide which device id? + ctx.device_id = 0; ObjectRef obj = CopyTo(args[i], ctx); func_args[i - 1] = obj; } @@ -164,18 +167,15 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } } -TVMContext VirtualMachine::GetParamsContext() const { +TVMContext VirtualMachine::GetContext(Index device_type) const { CHECK(!ctxs_.empty()) << "Context has not been initialized yet."; - // Use the fallback device if no device index is available. - int fallback_device_type = static_cast(ctxs_[0].device_type); - // TODO(wweic): For heterogeneous execution, get device information from byte + const auto& cit = std::find_if(ctxs_.begin(), ctxs_.end(), [&device_type](const TVMContext& c) { + return device_type == static_cast(c.device_type); + }); - const auto& cit = - std::find_if(ctxs_.begin(), ctxs_.end(), [&fallback_device_type](const TVMContext& c) { - return fallback_device_type == static_cast(c.device_type); - }); - return (cit == ctxs_.end() ? ctxs_[0] : *cit); + CHECK(cit != ctxs_.end()) << "device type " << device_type << " not found int the context list."; + return *cit; } void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { @@ -364,8 +364,8 @@ void VirtualMachine::RunLoop() { } if (!const_pool_[instr.const_index].defined()) { - // TODO(wweic) ctx could be obtained from the ctxs list. - const_pool_[instr.const_index] = CopyTo(constant_obj, ctxs_[0]); + TVMContext ctx = GetContext(exec_->const_device_type[instr.const_index]); + const_pool_[instr.const_index] = CopyTo(constant_obj, ctx); } WriteRegister(instr.dst, const_pool_[instr.const_index]); pc_++; @@ -511,11 +511,13 @@ void VirtualMachine::RunLoop() { auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = instr.alloc_storage.alignment; - DLOG(INFO) << "AllocStorage: allocation_size=" << size << "alignment=" << alignment - << "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint); + DLOG(INFO) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment + << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint) + << ", device_type=" << instr.alloc_storage.device_type; auto storage_obj = SimpleObjAllocator().make_object(); - auto it = allocators_.find(ctxs_[0]); + auto ctx = GetContext(instr.alloc_storage.device_type); + auto it = allocators_.find(ctx); CHECK(it != allocators_.end()) << "Did you forget to init the VirtualMachine with contexts?"; auto alloc = it->second; @@ -573,6 +575,21 @@ void VirtualMachine::RunLoop() { pc_++; goto main_loop; } + case Opcode::DeviceCopy: { + auto tensor_src = ReadRegister(instr.src); + NDArray src_data = Downcast(tensor_src); + DLContext src_ctx = src_data->ctx; + CHECK_EQ(static_cast(src_ctx.device_type), instr.src_device_type); + + DLContext dst_ctx; + dst_ctx.device_type = static_cast(instr.dst_device_type); + dst_ctx.device_id = 0; + + NDArray dst_data = src_data.CopyTo(dst_ctx); + WriteRegister(instr.dst, dst_data); + pc_++; + goto main_loop; + } default: LOG(FATAL) << "Unknown instruction opcode: " << int(instr.op); } diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 7f373369d301..e3eb4e8d654f 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -23,6 +23,7 @@ from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type import tvm.topi.testing +from tvm.relay.testing.config import ctx_list def int32(val): return relay.const(val, 'int32') @@ -33,8 +34,22 @@ def any_dims(ndim): shape.append(relay.Any()) return tuple(shape) -# TODO(@wweic): because vm doesn't support heterogeneous exec, we can only test -# shape function on CPU. +def check_result(args, mod, expected, flatten=False, assert_shape=False): + for kind in ["vm"]: + for tgt, ctx in ctx_list(): + ex = relay.create_executor(kind, mod=mod, ctx=ctx, target=tgt) + result = ex.evaluate()(*args) + result = result.asnumpy() + if assert_shape: + assert result.shape == expected, \ + "Shape mismatch: expect %s but got %s." \ + % (str(expected), str(result.shape)) + return + + if flatten: + result = result.flatten() + expected = expected.flatten() + tvm.testing.assert_allclose(result, expected) def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): dtype = 'float32' @@ -45,10 +60,7 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): x_np = np.random.uniform(size=x_np_shape).astype(dtype) y_np = np.random.uniform(size=y_np_shape).astype(dtype) res_np = np_op(x_np, y_np) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, y_np) - tvm.testing.assert_allclose(result.asnumpy(), res_np) + check_result([x_np, y_np], mod, res_np) def test_any_broadcast(): # Test broadcast with 1s @@ -69,10 +81,7 @@ def verify_any_elemwise(x_shape, x_np_shape, op, np_op): mod["main"] = relay.Function([x], op(x)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) res_np = np_op(x_np) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np) - tvm.testing.assert_allclose(result.asnumpy(), res_np) + check_result([x_np], mod, res_np) def test_any_elemwise(): verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt) @@ -103,10 +112,7 @@ def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op, dtype='float32'): mod['main'] = relay.Function([x], relay_op(x)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) res_np = np_op(x_np) - for kind in ['debug', 'vm']: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') - result = ex.evaluate()(x_np).asnumpy() - tvm.testing.assert_allclose(result, res_np) + check_result([x_np], mod, res_np) def test_any_full_like(): # zeros_like, ones_like @@ -124,10 +130,7 @@ def verify_any_full(x_np_shape, relay_op, np_op, dtype='float32', value=None): mod['main'] = relay.Function([x], out) res_np = np_op(x_np_shape) if value is None else np_op(x_np_shape, value) x_np = np.array(x_np_shape).astype("int32") - for kind in ['debug', 'vm']: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') - result = ex.evaluate()(x_np).asnumpy() - tvm.testing.assert_allclose(result, res_np) + check_result([x_np], mod, res_np) def test_any_full(): # zeros, ones, full @@ -151,10 +154,7 @@ def test_any_concat(): x_np = np.random.uniform(size=(3, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32') ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, y_np) - tvm.testing.assert_allclose(result.asnumpy(), ref) + check_result([x_np, y_np], mod, ref) def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False): x = relay.var('x', shape=x_shape, dtype="float32") @@ -172,6 +172,7 @@ def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newsha y = relay.reshape(relu_x, newshape=newshape) mod = tvm.IRModule() mod["main"] = relay.Function(params, y) + # check_result(args, mod, data, flatten=True) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") @@ -195,11 +196,7 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): mod["main"] = relay.Function([x], y) data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype) expected = np.argwhere(data) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data).asnumpy() - assert result.shape == expected.shape - tvm.testing.assert_allclose(result.flatten(), expected.flatten()) + check_result([data], mod, expected, flatten=True) def test_any_argwhere(): verify_any_argwhere(any_dims(1), (5,)) @@ -231,10 +228,7 @@ def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_s max_index = data_np.shape[axis] indices_np = np.random.randint(max_index, size=indices_np_shape).astype('int32') ref = np.take(data_np, indices_np, axis=axis) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, indices_np) - tvm.testing.assert_allclose(result.asnumpy(), ref) + check_result([data_np, indices_np], mod, ref) def test_any_take(): verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,)) @@ -251,11 +245,7 @@ def verify_any_tile(dshape, reps, np_dshape, np_reps): mod["main"] = relay.Function([x], y) x_data = np.random.uniform(size=np_dshape).astype("float32") ref_res = np.tile(x_data, reps=np_reps) - - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - res = ex.evaluate()(x_data) - tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5) + check_result([x_data], mod, ref_res) def test_any_tile(): verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1)) @@ -269,10 +259,7 @@ def test_any_shape_of(): mod = tvm.IRModule() mod["main"] = relay.Function([x], y) data = np.random.uniform(size=(3, 4)).astype('float32') - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - tvm.testing.assert_allclose(result.asnumpy(), np.array([3,4]).astype("int64")) + check_result([data], mod, np.array([3,4]).astype("int64")) x = relay.var('x', shape=any_dims(3), dtype='float32') y0 = relay.shape_of(x) @@ -280,10 +267,7 @@ def test_any_shape_of(): mod = tvm.IRModule() mod["main"] = relay.Function([x], y1) data = np.random.uniform(size=(2, 3, 4)).astype('float32') - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64")) + check_result([data], mod, np.array(3).astype("int64")) def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims, static_data_shape, ref_out_shape): @@ -293,11 +277,7 @@ def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims, y = reduce_op(data, axis, keepdims, exclude) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_reduce(): verify_any_reduce(relay.argmax, any_dims(3), None, False, False, (3, 4, 5), ()) @@ -316,11 +296,7 @@ def verify_any_layout_transform(data_shape, src_layout, dst_layout, static_data_ y = relay.layout_transform(data, src_layout, dst_layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_layout_transform(): verify_any_layout_transform(any_dims(4), "NCHW", "NHWC", (3, 4, 5, 6), (3, 5, 6, 4)) @@ -336,11 +312,7 @@ def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref y = relay.expand_dims(data, axis=axis, num_newaxis=num_newaxis) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_expand_dims(): verify_any_expand_dims(any_dims(3), 1, 2, (1, 2, 3), (1, 1, 1, 2, 3)) @@ -354,14 +326,12 @@ def verify_any_transpose(data_shape, axes, static_data_shape): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) ref_out = np.transpose(data_np, axes) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + check_result([data_np], mod, ref_out) def test_any_transpose(): verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2)) verify_any_transpose(any_dims(3), None, (2, 3, 4)) + # TODO(@zhiics) This test hangs, debug verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17)) verify_any_transpose(any_dims(2), (-1, 0), (3, 2)) @@ -373,10 +343,7 @@ def verify_any_squeeze(data_shape, axis, static_data_shape): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) ref_out = np.squeeze(data_np, axis) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + check_result([data_np], mod, ref_out) def test_any_squeeze(): verify_any_squeeze((1, relay.Any(), relay.Any()), (0,), (1, 9, 8)) @@ -391,11 +358,7 @@ def test_any_reshape_like(): mod["main"] = relay.Function([data, shape_like], y) data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype) shape_like_np = np.random.uniform(size=(3, 5, 6)).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, shape_like_np) - assert result.asnumpy().shape == shape_like_np.shape, \ - "Shape mismatch: expect %s but got %s." % (str(shape_like_np.shape), str(result.asnumpy().shape)) + check_result([data_np, shape_like_np], mod, shape_like_np.shape, assert_shape=True) def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation, data_layout, kernel_layout, out_layout, @@ -412,11 +375,7 @@ def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation mod["main"] = relay.Function([data, kernel], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, kernel_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True) # TODO(@kevinthesun): Need to fix the compute in conv2d_NCHWc to support any @pytest.mark.skip @@ -435,11 +394,7 @@ def verify_any_pool2d(pool_type, data_shape, pool_size, strides, padding, y = pool_func(data, pool_size, strides, padding, layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_pool2d(): verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()), @@ -457,11 +412,7 @@ def verify_any_global_pool2d(pool_type, data_shape, layout, static_data_shape, r y = pool_func(data, layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_global_pool2d(): verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()), @@ -499,11 +450,8 @@ def test_any_batch_flatten(): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype) ref_out_shape = (3, 30) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + # TODO(@zhiics) Check dense schedule + check_result([data_np], mod, ref_out_shape, assert_shape=True) def verify_any_dense(data_shape, weight_shape, units, static_data_shape, static_weight_shape, ref_out_shape): @@ -515,11 +463,7 @@ def verify_any_dense(data_shape, weight_shape, units, static_data_shape, mod["main"] = relay.Function([data, weight], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) weight_np = np.random.uniform(size=static_weight_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, weight_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True) def test_any_dense(): verify_any_dense(any_dims(2), any_dims(2), None, (4, 16), (8, 16), (4, 8)) @@ -533,10 +477,7 @@ def verify_any_pad(data_shape, pad_width, static_data_shape): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) ref_out = np.pad(data_np, pad_width) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + check_result([data_np], mod, ref_out) def test_any_pad(): verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3)) @@ -554,11 +495,7 @@ def verify_any_dilate(data_shape, strides, static_data_shape): for i in range(len(static_data_shape))) ref_out = np.zeros(shape=ref_shape, dtype=dtype) ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np - - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + check_result([data_np], mod, ref_out) def test_any_dilate(): verify_any_dilate(any_dims(1), (1,), (1,)) @@ -577,11 +514,7 @@ def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape): y = relay.nn.softmax(data, axis) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_softmax(): verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3)) @@ -608,10 +541,8 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): else: ref_out = sorted[0:kval] - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(*in_vals) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + # TODO(@zhiics) check topk cuda schedule + check_result(in_vals, mod, ref_out) def test_any_topk(): verify_any_topk(any_dims(1), 5, (10,), "float32") @@ -625,10 +556,7 @@ def test_fused_ops(): mod = tvm.IRModule() mod["main"] = relay.Function([x], y1) data = np.random.uniform(size=(5, 4)).astype('float32') - for kind in ["vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2) + check_result([data], mod, (data + 1) * 2) def test_arange_with_dynamic_shape(): # m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k') @@ -641,10 +569,7 @@ def test_arange_with_dynamic_shape(): data = np.random.rand(10, 5, 3).astype('float32') mod = tvm.IRModule() mod["main"] = relay.Function([x], y3) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1) + check_result([data], mod, np.array(range(10)).astype("int32")+1) def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape, data_np_shape, slice_mode="end", const_attrs=False): @@ -677,11 +602,7 @@ def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape, strides=strides, slice_mode=slice_mode) mod["main"] = relay.Function(args, y) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(*np_inputs) - tvm.testing.assert_allclose(result.asnumpy(), ref_res) - + check_result(*np_inputs, mod, ref_res) def test_any_strided_slice(): verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21)) @@ -724,10 +645,7 @@ def _body(i, st): mod["main"] = func data = np.array(0.0, dtype='int32') ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32") - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - np.testing.assert_allclose(result.asnumpy(), ref) + check_result([data], mod, ref) def test_recursive_concat_with_wrong_annotation(): """ @@ -789,11 +707,7 @@ def test_tuple_get_item(): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) ref_out_shape = (9, 2) - for kind in ["vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(ret.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_mixed_input_type(): mod = tvm.IRModule() @@ -811,11 +725,8 @@ def test_mixed_input_type(): data_np0 = np.random.uniform(size=static_data_shape).astype(dtype) data_np1 = np.random.uniform(size=static_data_shape).astype(dtype) ref_out_shape = (9, 4) - for kind in ["vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()([[data_np0, data_np0], data_np0], data_np1) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + # TODO(@zhiics) FAILED + check_result([data_np0, data_np0], mod, ref_out_shape, assert_shape=True) def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size, layout, static_boxes, static_box_indices_shape, ref_out_shape): @@ -829,11 +740,8 @@ def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_ mod["main"] = relay.Function([data, boxes, box_indices], y) data_np = np.random.uniform(size=data_shape).astype(dtype) boxes_np = np.random.uniform(size=static_boxes).astype(dtype) - box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, boxes_np, box_indices_np) - tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape) + box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype) + check_result([data_np, boxes_np, box_indices_np], mod, ref_out_shape, assert_shape=True) def test_any_crop_and_resize(): verify_any_crop_and_resize( @@ -863,10 +771,7 @@ def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shap y = relay.nn.mirror_pad(data, pad_width) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_mirror_pad(): verify_any_mirror_pad( @@ -882,11 +787,7 @@ def verify_any_ndarray_size(data_np_shape): mod['main'] = relay.Function([v], n) np_data = np.zeros(data_np_shape, dtype='float32') ref_res = np.size(np_data) - - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(np_data) - tvm.testing.assert_allclose(result.asnumpy(), ref_res) + check_result([np_data], mod, ref_res) def test_any_ndarray_size(): verify_any_ndarray_size((2,)) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 7a2ff55790a7..c55120e4b7cf 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -25,6 +25,41 @@ from tvm.relay import transform import tvm.testing +def _trace(module, metadata, _): + if metadata.name == 'ManifestAlloc': + pass # import pdb; pdb.set_trace() + + +def check_graph_runtime(target, ref_res, device, func, params, config, + opt_level, expected_index=None): + with tvm.transform.PassContext(opt_level=opt_level, config=config): + graph, lib, new_params = relay.build( + func, + target, + params=params) + contexts = [tvm.cpu(0), tvm.context(device)] + graph_json = json.loads(graph) + if "device_index" in graph_json["attrs"]: + device_index = graph_json["attrs"]["device_index"][1] + assert device_index == expected_index + mod = graph_runtime.create(graph, lib, contexts) + mod.set_input(**new_params) + mod.run() + res = mod.get_output(0).asnumpy() + tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) + + +def check_vm_runtime(target, ref_res, device, func, params, config, + opt_level, expected_index=None): + with tvm.transform.PassContext(opt_level=opt_level, trace=_trace, config=config): + mod = tvm.IRModule() + mod["main"] = func + exe = relay.vm.compile(mod, target) + ctx = [tvm.cpu(0), tvm.context(device)] + vm = tvm.runtime.vm.VirtualMachine(exe, ctx) + res = vm.invoke("main", **params) + tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) @@ -400,6 +435,7 @@ def run_fusible_network(dev, tgt): tmp_log = np.log(tmp_add) tmp_sub = np.subtract(tmp_sqrt, tmp_log) ref_res = np.exp(tmp_sub) + params = {"x": x_data, "y": y_data} def get_func(): add = relay.add(x, y) @@ -411,28 +447,6 @@ def get_func(): func = relay.Function([x, y], exp) return func - def test_runtime(target, device, func, fallback_device=None, - expected_index=None): - params = {"x": x_data, "y": y_data} - config = {} - if fallback_device: - config["relay.fallback_device_type"] = fallback_device.device_type - with tvm.transform.PassContext(opt_level=1, config=config): - graph, lib, params = relay.build( - func, - target, - params=params) - contexts = [tvm.cpu(0), tvm.context(device)] - graph_json = json.loads(graph) - if "device_index" in graph_json["attrs"]: - device_index = graph_json["attrs"]["device_index"][1] - assert device_index == expected_index - mod = graph_runtime.create(graph, lib, contexts) - mod.set_input(**params) - mod.run() - res = mod.get_output(0).asnumpy() - tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) - def test_fuse_log_add(device, tgt): """ Only log and add are fused.""" fallback_device = tvm.context("cpu") @@ -473,8 +487,13 @@ def expected(): dev_idx = ctx.device_type expected_index = [1, 1, 1, dev_idx, dev_idx, 1, 1, dev_idx, dev_idx] check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func, fallback_device, - expected_index) + opt_level = 1 + config = {"relay.fallback_device_type": fallback_device.device_type} + check_graph_runtime(target, ref_res, device, annotated_func, params, + config, opt_level, expected_index) + opt_level = 2 + check_vm_runtime(target, ref_res, device, annotated_func, params, + config, opt_level, expected_index) def test_fuse_all(device, tgt): """Fuse all operators.""" @@ -503,7 +522,13 @@ def annotated(): annotated_func = annotated() expected_func = get_func() check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func, fallback_device) + opt_level = 1 + config = {"relay.fallback_device_type": fallback_device.device_type} + check_graph_runtime(target, ref_res, device, annotated_func, params, + config, opt_level) + opt_level = 2 + check_vm_runtime(target, ref_res, device, annotated_func, params, + config, opt_level) def test_fallback_exp(device, tgt): fallback_device = tvm.context("cpu") @@ -540,16 +565,25 @@ def expected(): ctx = tvm.context(device, 0) dev_idx = ctx.device_type expected_index = [dev_idx, dev_idx, dev_idx, 1, 1] + opt_level = 1 + config = {"relay.fallback_device_type": fallback_device.device_type} check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func, fallback_device, - expected_index) + check_graph_runtime(target, ref_res, device, annotated_func, params, config, + opt_level, expected_index) + opt_level = 2 + check_vm_runtime(target, ref_res, device, annotated_func, params, config, + opt_level, expected_index) def test_fallback_all_operators(device, tgt): target = {device: tgt, "cpu": "llvm"} annotated_func = get_func() expected_func = get_func() check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func) + opt_level = 2 + check_graph_runtime(target, ref_res, device, annotated_func, params, {}, + opt_level) + check_vm_runtime(target, ref_res, device, annotated_func, params, {}, + opt_level) test_fuse_log_add(dev, tgt) @@ -557,6 +591,7 @@ def test_fallback_all_operators(device, tgt): test_fallback_exp(dev, tgt) test_fallback_all_operators(dev, tgt) + def run_unpropagatable_graph(dev, tgt): R""" The network is as following: a b c d @@ -608,20 +643,15 @@ def expected(): expected_index = [2, 2, 2, 1, 1, 1, 2, 2] check_annotated_graph(annotated_func, expected_func) params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data} - with tvm.transform.PassContext(opt_level=0, - config={"relay.fallback_device_type": - fallback_device.device_type}): - graph, lib, params = relay.build(annotated_func, target, params=params) - contexts = [tvm.cpu(0), tvm.context(dev)] - graph_json = json.loads(graph) - if "device_index" in graph_json["attrs"]: - device_index = graph_json["attrs"]["device_index"][1] - assert device_index == expected_index - mod = graph_runtime.create(graph, lib, contexts) - mod.set_input(**params) - mod.run() - res = mod.get_output(0).asnumpy() - tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) + opt_level = 0 + config = {"relay.fallback_device_type": fallback_device.device_type} + + check_graph_runtime(target, ref_res, dev, annotated_func, params, config, + opt_level, expected_index) + + opt_level = 2 + check_vm_runtime(target, ref_res, dev, annotated_func, params, config, + opt_level) @tvm.testing.requires_opencl @@ -686,5 +716,4 @@ def annotated(): test_annotate_all() test_annotate_none() test_conv_network() - test_check_run() test_tuple_get_item() diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py new file mode 100644 index 000000000000..a133a8cbf458 --- /dev/null +++ b/tests/python/relay/test_pass_context_analysis.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks + +import numpy as np + +import tvm +from tvm import relay + +data0 = relay.var("data0", shape=(1, relay.Any())) +data1 = relay.var("data1", shape=(1, relay.Any())) + +r0 = relay.cast(data0, dtype="int32") +w0 = relay.const(np.ndarray(shape=(30522, 768), dtype="float32")) +r1 = relay.take(w0, r0, axis=0) +r2 = relay.cast(data1, dtype="int32") +w1 = relay.const(np.ndarray(shape=(2, 768), dtype="float32")) +r3 = relay.take(w1, r2, axis=0) +r4 = relay.add(r1, r3) +r5 = relay.transpose(r4, axes=[1, 0, 2]) +r6 = relay.shape_of(r5, dtype="int32") +r7 = relay.take(r6, relay.const(0, dtype="int32")) +r8 = relay.cast(r7, dtype="float32") +r9 = relay.multiply(relay.const(1, dtype="float32"), r8) +r10 = relay.add(relay.const(0, dtype="float32"), r9) +r11 = relay.arange(relay.const(0, dtype="float32"), r10,\ + relay.const(1, dtype="float32"), dtype="float32") +r12 = relay.cast(r11, dtype="int32") +w2 = relay.const(np.ndarray(shape=(512, 768), dtype="float32")) +r13 = relay.take(w2, r12, axis=0) +r14 = relay.expand_dims(r13, axis=1) +r15 = relay.add(r5, r14) +r16 = relay.nn.dropout(r15, rate=0.1) +# r17 = relay.TupleGetItem(r16.astuple(), 0) +w3 = relay.const(np.ndarray(shape=(768,), dtype="float32")) +w4 = relay.const(np.ndarray(shape=(768,), dtype="float32")) +r18 = relay.nn.layer_norm(r16, w3, w4, epsilon=1e-12) +r19 = relay.op.reverse_reshape(r18, newshape=[-1, 0]) +w5 = relay.const(np.ndarray(shape=(768, 768), dtype="float32")) +r20 = relay.reverse_reshape(w5, newshape=[12, -1, 0]) +w6 = relay.const(np.ndarray(shape=(768, 768), dtype="float32")) +r21 = relay.reverse_reshape(w6, newshape=[12, -1, 0]) +w7 = relay.const(np.ndarray(shape=(768, 768), dtype="float32")) +r22 = relay.reverse_reshape(w7, newshape=[12, -1, 0]) +r23 = relay.Tuple([r20, r21, r22]) +r24 = relay.concatenate(r23, axis=-2) +r25 = relay.reverse_reshape(r24, newshape=[-1, 0]) +r26 = relay.nn.dense(r19, r25, units=2304) +w8 = relay.const(np.ndarray(shape=(768,), dtype="float32")) +w9 = relay.const(np.ndarray(shape=(768,), dtype="float32")) +w10 = relay.const(np.ndarray(shape=(768,), dtype="float32")) +r27 = relay.Tuple([w8, w9, w10]) +r28 = relay.concatenate(r27, axis=0) +r29 = relay.nn.bias_add(r26, r28, axis=-1) +r30 = relay.reshape(r29, newshape=[-1, 1, 2304]) +r31 = relay.reshape(r30, newshape=[0, 0, 12, 3, -1]) +r32 = relay.take(r31, relay.const(0, dtype="int64"), axis=3) +r33 = relay.transpose(r32, axes=[1, 2, 0, 3]) +r34 = relay.reverse_reshape(r33, newshape=[-1, 0, 0]) +r35 = relay.shape_of(r34, dtype="int32") +r36 = relay.take(r35, relay.const(2, dtype="int32")) +r37 = relay.cast(r36, dtype="float32") +r38 = relay.sqrt(r37) +r39 = relay.divide(r34, r38) +r40 = relay.take(r31, relay.const(1, dtype="int64"), axis=3) +r41 = relay.transpose(r40, axes=[1, 2, 0, 3]) +r42 = relay.reverse_reshape(r41, newshape=[-1, 0, 0]) +r43 = relay.nn.batch_matmul(r39, r42) +# r44 = relay.nn.softmax(r43) + +func = relay.Function([data0, data1], r43) +mod = tvm.ir.IRModule.from_expr(func) + +params = {} +exe = relay.vm.compile(mod, target="cuda", params=params) +rt = tvm.runtime.vm.VirtualMachine(exe, tvm.gpu(0)) + +seq_length = 128 +d0 = np.random.randint(0, 1000, size=(1, seq_length)).astype('float32') +d1 = np.ones((1, seq_length)).astype('float32') +d2 = np.asarray([seq_length]).astype('float32') + +rt.set_input("main", data0=d0, data1=d1) + +rt.run() diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 710025aeadb3..aac8f3529be5 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -661,7 +661,7 @@ def test_vm_reshape_tensor(): assert "reshape_tensor" in exec.bytecode check_result([x_np], x_np.reshape([4, 4, 8]), mod) - x = relay.var("x", shape=(8, 16), dtype="float32") + x = relay.var("x", shape=(tvm.tir.Any(), 16), dtype="float32") y = relay.reshape(x, [16, -1]) y = relay.reverse_reshape(y, [-1, 4, 0]) mod = tvm.IRModule() From 9cd7c7872716f6991609a7259af74b461c212125 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 5 Aug 2020 16:47:58 +0000 Subject: [PATCH 02/21] context analysis on module --- python/tvm/ir/module.py | 14 ++ python/tvm/relay/analysis/context_analysis.py | 169 +++++++++++++++++- python/tvm/relay/op/_tensor.py | 2 +- python/tvm/relay/op/_transform.py | 14 +- python/tvm/relay/transform/memory_alloc.py | 37 ++-- python/tvm/relay/transform/memory_plan.py | 8 - src/ir/module.cc | 3 + src/relay/backend/vm/compiler.cc | 55 +++--- src/relay/backend/vm/compiler.h | 20 ++- src/runtime/vm/bytecode.cc | 4 +- src/runtime/vm/vm.cc | 17 +- tests/python/relay/test_any.py | 42 +++-- tests/python/relay/test_vm.py | 2 +- 13 files changed, 293 insertions(+), 94 deletions(-) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 8d75d8e8ee21..851e40bd0ad2 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -118,6 +118,20 @@ def update(self, other): other = Module(other) return _ffi_api.Module_Update(self, other) + def update_func(self, gv, func): + """Update the function corresponding to a global variable in the + module. + + Parameters + ---------- + gv: GlobalVar + The global variable. + + func: tvm.relay.Function + The function to be inserted. + """ + return _ffi_api.Module_UpdateFunction(self, gv, func) + def get_global_var(self, name): """Get a global variable in the function by name. diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py index ac7a146312f7..e74badbfc8f7 100644 --- a/python/tvm/relay/analysis/context_analysis.py +++ b/python/tvm/relay/analysis/context_analysis.py @@ -21,18 +21,37 @@ from typing import Optional from collections import defaultdict +import tvm from ..expr_functor import ExprVisitor from ..function import Function -from .. import op, expr as _expr +from .. import ty, op, expr as _expr from ... import register_func, cpu from ..._ffi.runtime_ctypes import TVMContext +def is_closure(func): + """Check if a function is a closure. + + Parameters + ---------- + func : tvm.relay.Function + The input function. + + Returns + ------- + True if the input function is a closure, otherwise false. + """ + return hasattr(func, 'attrs') and \ + hasattr(func.attrs, 'Closure') and int(func.attrs.Closure) == 1 + + def is_device_copy(call): """Check if a call node is a device copy call. + Parameters ---------- call : tvm.relay.Call The call node to be checked. + Returns ------- ret : Boolean @@ -48,9 +67,11 @@ def is_device_copy(call): return isinstance(call.op.body, _expr.Call) and \ call.op.body.op == op.op.get("device_copy") + class DeviceDomain: """A class to represent the device of a domain, i.e. a segment of relay program. + Parameters ---------- ctx : Optional[tvm.runtime.TVMContext] @@ -61,10 +82,12 @@ def __init__(self, ctx: Optional[TVMContext]): def join(self, other: 'DeviceDomain') -> 'DeviceDomain': """Merge the device of two domains. + Parameters ---------- other : DeviceDomain The other domain to be merged. + Returns ------- ret : DeviceDomain @@ -105,10 +128,12 @@ def bottom(): def device_type(ctx): """Create a domain with the given device context. + Parameters ---------- ctx : tvm.runtime.TVMContext The device context used to construct a domain. + Returns ------- ret : DeviceDomain @@ -120,23 +145,35 @@ def device_type(ctx): class ContextAnalysis(ExprVisitor): """Compute on which device each sub-expression will execute. A union find algorithm is used to assign and merge the context domains. + Parameters ---------- - fallback_device : tvm.rutnime.TVMContext + mod : tvm.IRModule + The module that helps context analysis. + + current_func : tvm.relay.GlobalVar + The current function that is being analyzed. + + fallback_device : tvm.runtime.TVMContext The default device that could be attached to an expression. """ - def __init__(self, fallback_device): + def __init__(self, mod, current_func, fallback_device): super().__init__() self.expr_to_device = defaultdict(bottom) self.device_uf = {} + self.mod = mod + self.closures = {} + self.current_func = current_func self.fallback_device = fallback_device def lookup(self, device): """Find the root domain of a given device domain. + Parameters ---------- device : DeviceDomain The domain that is used to query the root domain. + Returns ------- ret : DeviceDomain @@ -148,12 +185,15 @@ def lookup(self, device): def unify(self, lhs, rhs): """Unify the device context of two domains. + Parameters ---------- lhs : DeviceDomain The lhs domain to unify. + rhs : DeviceDomain The rhs domain to unify. + Returns ------- ret : DeviceDomain @@ -170,12 +210,15 @@ def unify(self, lhs, rhs): def unify_expr(self, lhs, rhs): """Compute the device type of both expressions and unify them. + Parameters ---------- lhs : tvm.relay.Expr The lhs expression to unify. + rhs : tvm.relay.Expr The rhs expression to unify. + Returns ------- ret : DeviceDomain @@ -185,10 +228,12 @@ def unify_expr(self, lhs, rhs): def device_for(self, expr): """Find the domain that contains the given expr. + Parameters ---------- expr : tvm.relay.Expr The expression used to lookup a domain. + Returns ------- ret : DeviceDomain @@ -200,18 +245,22 @@ def device_copy(self, inps, outputs, src_dev_type, dst_dev_type): """Unify the device context for device copy node. Device copy node is the only node that carries information in the input program. The device attribute of other nodes are propagated from it. + Parameters ---------- inps : List[tvm.relay.Expr] The input expression to the device copy node. The device type of the input should be the same as the source device type of the copy node. + outputs : List[tvm.relay.Expr] The output expression of the device copy node. The device type of the output should be the same as the destination device type of the copy node. + src_dev_type : int The source device type of the copy node. + dst_dev_type : int The destination device type of the copy node. """ @@ -225,21 +274,26 @@ def device_copy(self, inps, outputs, src_dev_type, dst_dev_type): def unify_call(self, call_op, inputs, outputs, device=None): """Unify the domain of inputs and outputs of a relay Call. + Parameters ---------- op : tvm.relay.Expr The op of a call node. + inputs : List[tvm.relay.Expr] The inputs of the call. + outputs : List[tvm.relay.Expr] The outputs of the call. + Returns ------- The unified domain. + Note ---- For most call nodes, the op, inputs, and outputs should all be in the - same domain, i.e. have the same context. However, device_copy call node + same domain, i.e. having the same context. However, device_copy call node needs to be handled different as it copies data from one device to another. """ @@ -330,24 +384,102 @@ def visit_call(self, call): self.unify(device, self.device_for(call.op)) self.unify(device, self.device_for(call.op.body)) self.visit(call.op) + elif isinstance(call.op, _expr.GlobalVar): + device = self.device_for(call) + assert self.mod, "Cannot analyze context on a globalvar without module" + func = self.mod[call.op] + + assert len(call.args) == len(func.params) + + for arg, param in zip(call.args, func.params): + self.visit(arg) + # Save the the arg to function mapping for closures as it will + # be invoked/unified later. + if isinstance(arg.checked_type, ty.FuncType): + assert arg in self.closures + self.closures[param] = self.closures[arg] + self.unify(self.device_for(arg), self.device_for(param)) + + device = self.unify(device, self.device_for(call.op)) + device = self.unify(device, self.device_for(func)) + device = self.unify(device, self.device_for(func.body)) + # Step into the callee. We need to skip recursive calls, otherwise, it + # would be a infinite loop, so does mutual recursive calls + cur_func = self.current_func + self.current_func = call.op + if cur_func.name_hint != call.op.name_hint: + self.visit(func) + self.current_func = cur_func + elif isinstance(call.op, _expr.Var): + # It is a closure when we call a var + # Unify the corresponding arguement and parameter + device = self.device_for(call) + assert call.op in self.closures, f"Cannot find {call.op}" + glb_var = self.closures[call.op] + assert self.mod, "Cannot analyze context on a globalvar without module" + func = self.mod[glb_var] + # Unify the underlying function for clousre or currying funcitons. + while is_closure(func) or (isinstance(func.body, _expr.Let) and + func.body.var in self.closures): + device = self.unify(device, self.device_for(func)) + if is_closure(func): + func = func.body + elif (isinstance(func.body, _expr.Let) and func.body.var in self.closures): + func = self.mod[self.closures[func.body.var]] + + assert isinstance(func, Function) + assert len(call.args) == len(func.params) + + for dev_arg, dev_param in zip(call.args, func.params): + self.visit(dev_arg) + self.unify(self.device_for(dev_arg), self.device_for(dev_param)) + + device = self.unify(device, self.device_for(call.op)) + device = self.unify(device, self.device_for(glb_var)) + device = self.unify(device, self.device_for(func)) + cur_func = self.current_func + # Step into the closure. + self.current_func = glb_var + if not tvm.ir.structural_equal(cur_func, glb_var): + self.visit(func) + self.current_func = cur_func else: self.unify_call(call, call.args, [call]) super().visit_call(call) + def _extract_closure(self, expr): + while isinstance(expr, _expr.Let): + expr = expr.value + if isinstance(expr, _expr.GlobalVar): + return expr + elif isinstance(expr, _expr.Call) and isinstance(expr.op, + _expr.GlobalVar): + return expr.op + return None + def visit_let(self, let): while isinstance(let, _expr.Let): + # Save currying/closures since they will be invoked later + if isinstance(let.value.checked_type, ty.FuncType): + gv = self._extract_closure(let) + assert gv + self.closures[let.var] = gv + self.unify(self.device_for(let.var), self.device_for(let.value)) self.unify_expr(let, let.body) - self.visit(let.var) self.visit(let.value) let = let.body + self.visit(let) + def visit_function(self, f): self.unify(self.device_for(f), self.device_for(f.body)) super().visit_function(f) def visit_tuple(self, tup): # We only support tuple with the same of device. + if not tup: + return device = self.device_for(tup[0]) for i in range(1, len(tup)): device = self.unify(device, self.device_for(tup[i])) @@ -357,10 +489,21 @@ def visit_tuple(self, tup): def visit_tuple_getitem(self, t): value = t.tuple_value if isinstance(t.tuple_value, _expr.Tuple): + self.unify(self.device_for(t), self.device_for(value)) value = t.tuple_value[t.index] self.unify(self.device_for(t), self.device_for(value)) super().visit_tuple_getitem(t) + def visit_match(self, m): + # For match node, we unify the value and the rhs of each clause + device = self.unify(self.device_for(m), self.device_for(m.data)) + for c in m.clauses: + device = self.unify(device, self.device_for(c.rhs)) + super().visit_match(m) + + def visit_global_var(self, gv): + self.device_for(gv) + def visit_var(self, var): self.device_for(var) @@ -369,6 +512,7 @@ def visit_constant(self, const): def results(self): """Return the analysis result. + Returns ------- ret : Dict[tvm.relay.Expr, DeviceDomain] @@ -396,23 +540,30 @@ def _annotator(exp): return _annotator -def context_analysis(expr, fallback_device): +def context_analysis(mod, fallback_device): """Perform device context analysis on a given relay program. This requires that the program has already been annotated and rewritten by replacing on device annotations with device copy nodes. + Parameters ---------- - expr : tvm.relay.Expr - The expression for analysis + mod : tvm.IRModule + The IRModule for analysis + fallback_device : tvm.runtime.TVMContext The default device context + Returns ------- ret : Dict[tvm.relay.Expr, [int]] The mapping of each expression to the device context that is represented in a list form as TVMContext is not a runtime object. """ - ca = ContextAnalysis(fallback_device) + assert isinstance(mod, tvm.IRModule) + # TODO(@zhiics) Apply the pass to all functions/entries + entry = mod.get_global_var("main") + ca = ContextAnalysis(mod, entry, fallback_device) + expr = mod[entry] ca.visit(expr) ret = defaultdict(list) for key, val in ca.results().items(): diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 46434461d429..c81d4c51c502 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -84,10 +84,10 @@ register_injective_schedule("left_shift") register_injective_schedule("shape_of") register_injective_schedule("ndarray_size") +register_injective_schedule("device_copy") register_broadcast_schedule("fast_exp") register_broadcast_schedule("fast_tanh") register_broadcast_schedule("fast_erf") -register_broadcast_schedule("device_copy") # zeros diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 70f56a489314..937c36e60919 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -217,7 +217,7 @@ def concatenate_shape_func(attrs, inputs, _): return [_concatenate_shape_func(inputs, convert(axis))] @script -def _reshape_shape_func_input_shape(data_shape, newshape, ndim, reverse=True): +def _reshape_shape_func_input_shape(data_shape, newshape, ndim): out = output_tensor((ndim,), "int64") src_idx = 0 dst_idx = 0 @@ -677,15 +677,3 @@ def split_shape_func(attrs, inputs, _): convert(i), convert(indices_or_sections), convert(axis)) for i in range(num_out)] - -@_reg.register_shape_func("contrib_reverse_reshape", False) -def contrib_reverse_reshape_shape_func(attrs, inputs, out_ndims): - newshape = get_const_tuple(attrs.newshape) - print(inputs[0]) - data_shape = reversed(inputs[0]) - newshape = reversed(newshape) - return [_reshape_shape_func_input_shape(data_shape, - convert(newshape), - out_ndims[0])] - - diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 106fb742d749..e851bb93794c 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -21,7 +21,7 @@ import numpy as np import logging -from tvm.ir.transform import PassContext +from tvm.ir.transform import PassContext, module_pass from tvm import nd, container, tir from ..function import Function from ..expr_functor import ExprVisitor, ExprMutator @@ -37,6 +37,8 @@ from ..analysis.context_analysis import ContextAnalysis, mk_analysis_annotator from ..._ffi.runtime_ctypes import TVMContext +# logging.basicConfig(level=logging.DEBUG) + def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): offset = expr.const(0, dtype="int64") return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape) @@ -192,14 +194,14 @@ def emit_shape_func(self, scope, func, new_args): cpu_ctx = nd.cpu(0) for i, (arg, state) in enumerate(zip(new_args, input_states)): state = int(state) + ctx = self.get_context(arg) # Pass Shapes if state == 2: for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)): - ctx = self.get_context(subexp) if ctx.device_type != cpu_ctx.device_type: subexp = self.device_copy(scope, subexp, ctx, cpu_ctx, j) let_in_arg = scope.let("in_arg_{0}".format(input_pos + j), subexp) - sh_of = self.visit(self.shape_of(subexp)) + sh_of = self.visit(self.shape_of(let_in_arg)) shape_func_ins.append( scope.let("in_shape_{0}".format(input_pos + j), sh_of)) input_pos += 1 @@ -207,7 +209,6 @@ def emit_shape_func(self, scope, func, new_args): # Pass Inputs elif state == 1: new_arg = self.visit(arg) - ctx = self.get_context(arg) if ctx.device_type != cpu_ctx.device_type: new_arg = self.device_copy(scope, new_arg, ctx, cpu_ctx, i) shape_func_ins.append( @@ -221,6 +222,8 @@ def emit_shape_func(self, scope, func, new_args): out_shapes = [] for i, out in enumerate(cfunc.outputs): tt = ty.TensorType(out.shape, out.dtype) + # Put shape func on CPU. This also ensures that everything between + # shape_of and shape_func are on CPU. alloc = self.make_static_allocation(scope, tt, cpu_ctx, i) alloc = scope.let("shape_func_out_{0}".format(i), alloc) out_shapes.append(alloc) @@ -335,40 +338,46 @@ def visit_call(self, call): return super().visit_call(call) -@transform.function_pass(opt_level=0) +@module_pass(opt_level=0) class ManifestAlloc: """The explicit pass wrapper around ManifestAlloc.""" def __init__(self, target_host, targets): self.target_host = target_host self.targets = targets - def transform_function(self, func, mod, _): + def transform_module(self, mod, _): # TODO(@jroesch): Is there a way to do one shot initialization? # can we have def pass_init? mod.import_from_std("core.rly") assert isinstance(self.targets, (dict, container.Map)) + cur_func = mod.get_global_var("main") if len(self.targets) > 1: pass_ctx = PassContext.current() if "relay.fallback_device_type" in pass_ctx.config: fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"]) else: fallback_ctx = cpu(0) - ca = ContextAnalysis(TVMContext(fallback_ctx.device_type, 0)) + ca = ContextAnalysis(mod, cur_func, TVMContext(fallback_ctx.device_type, 0)) else: dev, _ = self.targets.items()[0] - ca = ContextAnalysis(nd.context(dev.value)) + ca = ContextAnalysis(mod, cur_func, nd.context(dev.value)) + func = mod["main"] # We use logger here to help debug. logging.debug("-----BEFORE ANALYSIS-----") - logging.debug(func.astext(False)) + logging.debug(mod.astext(False)) ca.visit(func) logging.debug("-----AFTER ANALYSIS-----") - logging.debug(func.astext(show_meta_data=False, - annotate=mk_analysis_annotator(ca.results()))) - ea = ManifestAllocPass(self.target_host, ca.results()) - func = ea.visit(func) - return func + logging.debug(mod.astext(show_meta_data=False, + annotate=mk_analysis_annotator(ca.results()))) + ca_res = ca.results() + gv_funcs = mod.functions + for gv, f in gv_funcs.items(): + ea = ManifestAllocPass(self.target_host, ca_res) + f = ea.visit(f) + mod.update_func(gv, f) + return mod register_func("relay.transform.ManifestAlloc", ManifestAlloc) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index d93ee03bdfd8..8f21af9292a9 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -84,10 +84,6 @@ def grow( self.alignment = alignment if self.ctx: - print(self.ctx.device_type) - print(ctx.device_type) - print(self.ctx.device_id) - print(ctx.device_id) assert (self.ctx.device_type == ctx.device_type and self.ctx.device_id == ctx.device_id), "must have matching context" else: @@ -286,10 +282,6 @@ def process_alloc_storage(self, dynamic_regions, lhs, call): dynamic_regions.append(lhs) region = self.current_region(dtype) - if region.ctx and (region.ctx.device_type != ctx.device_type or \ - region.ctx.device_id != ctx.device_id): - return lhs, call - region.grow(lhs, size, alignment, ctx, dtype) return lhs, region.var diff --git a/src/ir/module.cc b/src/ir/module.cc index bcab39aabf32..66bce0f6b882 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -448,6 +448,9 @@ TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule mod->Update(from); }); +TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") + .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); + TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { mod->Import(path); }); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b89911d60227..4ff0073eb310 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -52,8 +52,6 @@ namespace tvm { namespace relay { -using ExprDeviceMap = std::unordered_map; - namespace transform { Pass LambdaLift(); @@ -65,11 +63,11 @@ Pass ManifestAlloc(Target target_host, vm::TargetsMap targets) { return (*f)(target_host, targets); } -ExprDeviceMap ContextAnalysis(Expr expr, TVMContext default_device) { +vm::ExprDeviceMap ContextAnalysis(IRModule mod, TVMContext default_device) { auto f = tvm::runtime::Registry::Get("relay.analysis.ContextAnalysis"); CHECK(f != nullptr) << "could not load context analysis pass"; - Map > m = (*f)(expr, default_device); - ExprDeviceMap ret; + Map> m = (*f)(mod, default_device); + vm::ExprDeviceMap ret; for (const auto& it : m) { TVMContext ctx; Array ints = it.second; @@ -270,34 +268,20 @@ int GetFallbackDevice() { class VMFunctionCompiler : ExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host, + ExprDeviceMap expr_device_map) : last_register_(0), registers_num_(0), engine_(CompileEngine::Global()), context_(context), - target_host_(target_host) { + target_host_(target_host), + expr_device_map_(std::move(expr_device_map)) { for (const auto& it : targets) { targets_[it.first->value] = it.second; } } VMFunction Compile(const GlobalVar& var, const Function& func) { - // Collect the annotated device information. - // This indicates which device each Relay expr should be executed on. - TVMContext default_device; - if (targets_.size() > 1) { - int fallback_dev = GetFallbackDevice(); - default_device.device_type = static_cast(fallback_dev); - default_device.device_id = 0; - expr_device_map_ = transform::ContextAnalysis(func, default_device); - } else { - default_device.device_type = static_cast((targets_.begin())->first); - if (default_device.device_type != kDLCPU) { - default_device.device_id = 0; - expr_device_map_ = transform::ContextAnalysis(func, default_device); - } - } - size_t i = 0; // We then assign register num to the free variables for (auto param : func->params) { @@ -964,11 +948,15 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe // the global state. exec_->functions.resize(context_.module->functions.size()); + // Collect the annotated device information. + // This indicates which device each Relay expr should be executed on. + ExprDeviceMap expr_device_map = AnalyzeContext(); + for (auto named_func : context_.module->functions) { auto gvar = named_func.first; if (auto* n = named_func.second.as()) { auto func = GetRef(n); - VMFunctionCompiler func_compiler(&context_, targets_, target_host_); + VMFunctionCompiler func_compiler(&context_, targets_, target_host_, expr_device_map); auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); @@ -1187,6 +1175,25 @@ void VMCompiler::Codegen() { } } +ExprDeviceMap VMCompiler::AnalyzeContext() const { + TVMContext default_device; + ExprDeviceMap expr_device_map; + if (targets_.size() > 1) { + int fallback_dev = GetFallbackDevice(); + default_device.device_type = static_cast(fallback_dev); + default_device.device_id = 0; + expr_device_map = transform::ContextAnalysis(context_.module, default_device); + } else { + const auto& tgt = targets_.begin(); + default_device.device_type = static_cast((*tgt).first->value); + if (default_device.device_type != kDLCPU) { + default_device.device_id = 0; + expr_device_map = transform::ContextAnalysis(context_.module, default_device); + } + } + return expr_device_map; +} + runtime::Module CreateVMCompiler() { auto exec = make_object(); return runtime::Module(exec); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 78cff7ba8d4f..b2c8016cc7d3 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,6 +62,7 @@ using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; using TargetsMap = Map; +using ExprDeviceMap = std::unordered_map; struct VMCompilerContext { // The module context for the compilation @@ -105,7 +106,7 @@ class VMCompiler : public runtime::ModuleNode { * * \param mod Relay Module * \param targets For heterogeneous compilation, it is a dictionary indicating context - to target mapping. For homogeneous compilation, it is a build target. + * to target mapping. For homogeneous compilation, it is a build target. * \param target_host Host compilation target, if target is device. */ void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); @@ -114,11 +115,28 @@ class VMCompiler : public runtime::ModuleNode { void Codegen(); protected: + /* + * \brief Perform a series of optimizations on the input IR module. + * + * \param mod The input IRModule. + * \param targets For heterogeneous compilation, it is a dictionary indicating context + * to target mapping. For homogeneous compilation, it is a build target. + * \param target_host Host compilation target. + * + * \return The optimized IRModule. + */ IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets, const Target& target_host); + /*! + * \brief Populate the global function names in a map where the value is used + * as the index by the VMFunctions. + */ void PopulateGlobalMap(); + /*! \brief Analyze the device context of each expression. */ + ExprDeviceMap AnalyzeContext() const; + protected: /*! \brief Target devices. */ TargetsMap targets_; diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index 754858fb5d0e..78972beb1ed2 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -623,8 +623,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::DeviceCopy: { - os << "device_copy $" << instr.dst << " $" << instr.src << " " << instr.src_device_type << " " - << instr.dst_device_type; + os << "device_copy $" << instr.dst << " $" << instr.src << " " << instr.dst_device_type << " " + << instr.src_device_type; break; } default: diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index b784f3a12737..cf12670bca3c 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -68,8 +68,17 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { if (nd_array->ctx.device_type != ctx.device_type) { return nd_array.CopyTo(ctx); } + return src; + } else { + CHECK(src->IsInstance()) + << "VM data must be NDArray or a list of NDArray, but received: " << src->_type_key; + std::vector ret; + ADT adt = Downcast(src); + for (size_t i = 0; i < adt.size(); i++) { + ret.push_back(CopyTo(adt[i], ctx)); + } + return ADT(0, ret.begin(), ret.end()); } - return src; } std::vector ToShape(NDArray shape_tensor) { @@ -151,9 +160,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, std::vector func_args(param_names.size()); for (int i = 1; i < args.size(); ++i) { TVMContext ctx; - int device_type = vm_func.params_device_type[i-1]; + int device_type = vm_func.params_device_type[i - 1]; ctx.device_type = DLDeviceType(device_type); - //TODO(zhiics) How to decide which device id? + // TODO(zhiics) Use virtual device id ctx.device_id = 0; ObjectRef obj = CopyTo(args[i], ctx); func_args[i - 1] = obj; @@ -338,7 +347,7 @@ void VirtualMachine::RunLoop() { while (true) { main_loop: auto const& instr = code_[this->pc_]; - DLOG(INFO) << "Executing(" << pc_ << "): " << instr; + LOG(INFO) << "Executing(" << pc_ << "): " << instr; #if USE_RELAY_DEBUG InstructionPrint(std::cout, instr); #endif // USE_RELAY_DEBUG diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index e3eb4e8d654f..38ceb2e92e55 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -34,9 +34,13 @@ def any_dims(ndim): shape.append(relay.Any()) return tuple(shape) -def check_result(args, mod, expected, flatten=False, assert_shape=False): - for kind in ["vm"]: +def check_result(args, mod, expected, flatten=False, assert_shape=False, + only_vm=False): + for kind in ["debug", "vm"]: for tgt, ctx in ctx_list(): + if kind == "debug" and (only_vm or ctx.device_type != + tvm.cpu().device_type): + continue ex = relay.create_executor(kind, mod=mod, ctx=ctx, target=tgt) result = ex.evaluate()(*args) result = result.asnumpy() @@ -172,13 +176,7 @@ def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newsha y = relay.reshape(relu_x, newshape=newshape) mod = tvm.IRModule() mod["main"] = relay.Function(params, y) - # check_result(args, mod, data, flatten=True) - - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(*args).asnumpy() - assert result.shape == out_shape - tvm.testing.assert_allclose(result.flatten(), data.flatten()) + check_result(args, mod, data, flatten=True) def test_any_reshape(): for variable_newshape in [False, True]: @@ -196,7 +194,14 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): mod["main"] = relay.Function([x], y) data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype) expected = np.argwhere(data) - check_result([data], mod, expected, flatten=True) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data).asnumpy() + assert result.shape == expected.shape + tvm.testing.assert_allclose(result.flatten(), expected.flatten()) + + # TODO(@zhiics) argwhere gpu schedule is currently not avaiable + # check_result([data], mod, expected, flatten=True) def test_any_argwhere(): verify_any_argwhere(any_dims(1), (5,)) @@ -331,7 +336,6 @@ def verify_any_transpose(data_shape, axes, static_data_shape): def test_any_transpose(): verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2)) verify_any_transpose(any_dims(3), None, (2, 3, 4)) - # TODO(@zhiics) This test hangs, debug verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17)) verify_any_transpose(any_dims(2), (-1, 0), (3, 2)) @@ -450,7 +454,6 @@ def test_any_batch_flatten(): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype) ref_out_shape = (3, 30) - # TODO(@zhiics) Check dense schedule check_result([data_np], mod, ref_out_shape, assert_shape=True) def verify_any_dense(data_shape, weight_shape, units, static_data_shape, @@ -541,8 +544,13 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): else: ref_out = sorted[0:kval] - # TODO(@zhiics) check topk cuda schedule - check_result(in_vals, mod, ref_out) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(*in_vals) + tvm.testing.assert_allclose(result.asnumpy(), ref_out) + + # TODO(@zhiics) Fix topk cuda schedule for dynamic inputs + # check_result(in_vals, mod, ref_out) def test_any_topk(): verify_any_topk(any_dims(1), 5, (10,), "float32") @@ -602,7 +610,7 @@ def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape, strides=strides, slice_mode=slice_mode) mod["main"] = relay.Function(args, y) - check_result(*np_inputs, mod, ref_res) + check_result(np_inputs, mod, ref_res) def test_any_strided_slice(): verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21)) @@ -725,8 +733,8 @@ def test_mixed_input_type(): data_np0 = np.random.uniform(size=static_data_shape).astype(dtype) data_np1 = np.random.uniform(size=static_data_shape).astype(dtype) ref_out_shape = (9, 4) - # TODO(@zhiics) FAILED - check_result([data_np0, data_np0], mod, ref_out_shape, assert_shape=True) + check_result([[[data_np0, data_np0], data_np0], data_np1], mod, + ref_out_shape, assert_shape=True, only_vm=True) def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size, layout, static_boxes, static_box_indices_shape, ref_out_shape): diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index aac8f3529be5..710025aeadb3 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -661,7 +661,7 @@ def test_vm_reshape_tensor(): assert "reshape_tensor" in exec.bytecode check_result([x_np], x_np.reshape([4, 4, 8]), mod) - x = relay.var("x", shape=(tvm.tir.Any(), 16), dtype="float32") + x = relay.var("x", shape=(8, 16), dtype="float32") y = relay.reshape(x, [16, -1]) y = relay.reverse_reshape(y, [-1, 4, 0]) mod = tvm.IRModule() From 5e67950c41ef7e8fb929c526e50099178e08f195 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 17 Aug 2020 06:06:06 +0000 Subject: [PATCH 03/21] fix profiler --- src/runtime/vm/profiler/vm.cc | 15 ++++++++++---- src/runtime/vm/vm.cc | 2 +- .../python/relay/benchmarking/benchmark_vm.py | 19 +++++++++--------- .../unittest/test_runtime_vm_profiler.py | 20 +++++++++---------- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 32aedc527e24..63001634558e 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -106,10 +106,17 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun Index output_size, const std::vector& args) { CHECK(exec_); CHECK(!ctxs_.empty()) << "Context has not been initialized yet."; - // TODO(@zhiics) Need to record the device type of each packed func so that - // we can correctly sync. - Index fallback_device_type = static_cast(ctxs_[0].device_type); - auto ctx = this->GetContext(fallback_device_type); + // The device context of any input of the operator is used for + // synchronization. + CHECK_GT(arg_count, 0U); + ObjectRef arg = args[0]; + while (arg->IsInstance()) { + ADT adt = Downcast(arg); + arg = adt[0]; + } + CHECK(arg->IsInstance()); + auto nd_array = Downcast(arg); + auto ctx = nd_array->ctx; TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index cf12670bca3c..df33f2ddff61 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -347,7 +347,7 @@ void VirtualMachine::RunLoop() { while (true) { main_loop: auto const& instr = code_[this->pc_]; - LOG(INFO) << "Executing(" << pc_ << "): " << instr; + DLOG(INFO) << "Executing(" << pc_ << "): " << instr; #if USE_RELAY_DEBUG InstructionPrint(std::cout, instr); #endif // USE_RELAY_DEBUG diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py index 80e9e4141c1d..073ad6a4ca05 100644 --- a/tests/python/relay/benchmarking/benchmark_vm.py +++ b/tests/python/relay/benchmarking/benchmark_vm.py @@ -67,8 +67,8 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32', if measure: print("Evaluate vm inference cost of {} on {}".format(model, repr(ctx))) - ftimer = rly_vm.mod.time_evaluator("invoke", ctx, number=number, - repeat=repeat) + ftimer = rly_vm.module.time_evaluator("invoke", ctx, number=number, + repeat=repeat) # Measure in millisecond. prof_res = np.array(ftimer("main", data).results) * 1000 print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" % @@ -78,14 +78,13 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32', # random input data = np.random.uniform(size=data_shape).astype(dtype) - target = "llvm" - ctx = tvm.cpu(0) - - tvm_out = get_graph_runtime_output(mod, tvm.nd.array(data.astype(dtype)), - params, target, ctx, dtype) - vm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params, - target, ctx, dtype) - tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5) + + for target, ctx in testing.ctx_list(): + tvm_out = get_graph_runtime_output(mod, tvm.nd.array(data.astype(dtype)), + params, target, ctx, dtype) + vm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params, + target, ctx, dtype) + tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5) def test_mlp(): diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index 97b54c6c5505..cbcb022589e6 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -16,25 +16,23 @@ # under the License. import numpy as np -import tvm -from tvm import te from tvm.runtime import profiler_vm from tvm import relay -from tvm.relay.testing import resnet +from tvm.relay.testing import resnet, ctx_list def test_basic(): mod, params = resnet.get_workload() - target = 'llvm' - ctx = tvm.cpu() if not profiler_vm.enabled(): return - exe = relay.vm.compile(mod, target, params=params) - vm = profiler_vm.VirtualMachineProfiler(exe, ctx) - data = np.random.rand(1, 3, 224, 224).astype('float32') - res = vm.invoke("main", [data]) - print("\n{}".format(vm.get_stat())) - print("\n{}".format(vm.get_stat(False))) + for target, ctx in ctx_list(): + exe = relay.vm.compile(mod, target, params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, ctx) + + data = np.random.rand(1, 3, 224, 224).astype('float32') + res = vm.invoke("main", [data]) + print("\n{}".format(vm.get_stat())) + print("\n{}".format(vm.get_stat(False))) if __name__ == "__main__": test_basic() From 534e9c4e6191a3972e0bbc60197fca93f71101f8 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 17 Aug 2020 22:47:56 +0000 Subject: [PATCH 04/21] fix memory plan --- python/tvm/relay/analysis/context_analysis.py | 2 +- python/tvm/relay/transform/memory_plan.py | 5 ++++ src/relay/backend/vm/compiler.cc | 24 +++++++++---------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py index e74badbfc8f7..517f59804cc6 100644 --- a/python/tvm/relay/analysis/context_analysis.py +++ b/python/tvm/relay/analysis/context_analysis.py @@ -294,7 +294,7 @@ def unify_call(self, call_op, inputs, outputs, device=None): ---- For most call nodes, the op, inputs, and outputs should all be in the same domain, i.e. having the same context. However, device_copy call node - needs to be handled different as it copies data from one device to + needs to be handled differently as it copies data from one device to another. """ device = device if device else bottom() diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 8f21af9292a9..a50f49279770 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -280,6 +280,11 @@ def process_alloc_storage(self, dynamic_regions, lhs, call): if not isinstance(size, expr.Constant): self.enter_scope() dynamic_regions.append(lhs) + else: + region = self.current_region(dtype) + if region.ctx and region.ctx.device_type != ctx.device_type: + self.enter_scope() + dynamic_regions.append(lhs) region = self.current_region(dtype) region.grow(lhs, size, alignment, ctx, dtype) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 4ff0073eb310..a71fc8e96aae 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1006,23 +1006,23 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { // Manifest the allocations needed for the shape functions. pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); - // // Fuse the shape functions. - // pass_seqs.push_back(transform::FuseOps()); + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); - // // Perform memory planning in order to coalesce/reduce allocations. - // pass_seqs.push_back(transform::MemoryPlan()); + // Perform memory planning in order to coalesce/reduce allocations. + pass_seqs.push_back(transform::MemoryPlan()); - // // Compute away constant computation introduced by coalescing allocations. - // pass_seqs.push_back(transform::FoldConstant()); + // Compute away constant computation introduced by coalescing allocations. + pass_seqs.push_back(transform::FoldConstant()); - // // Fuse the shape functions. - // pass_seqs.push_back(transform::FuseOps()); + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); - // // Create allocations for math introduced by dynamic region math. - // pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); + // Create allocations for math introduced by dynamic region math. + pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); - // // Compute away possibly introduced constant computation. - // pass_seqs.push_back(transform::FoldConstant()); + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); // Lift constants to the top-level of the block to simplify VM code generation. // TODO(@icemelon9, @jroesch): Remove this pass for now because some From f4a903ec5b23a283b57f8479d32e15a650634937 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 19 Aug 2020 01:31:30 +0000 Subject: [PATCH 05/21] add more unification --- python/tvm/relay/analysis/context_analysis.py | 46 +++++++++---------- python/tvm/relay/transform/memory_alloc.py | 38 +++++++++------ 2 files changed, 46 insertions(+), 38 deletions(-) diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py index 517f59804cc6..7c651e122bc0 100644 --- a/python/tvm/relay/analysis/context_analysis.py +++ b/python/tvm/relay/analysis/context_analysis.py @@ -331,22 +331,19 @@ def visit_call(self, call): self.device_copy(inps, outs, src_dev_type, dst_dev_type) super().visit_call(call) elif call.op == op.op.get("memory.alloc_storage"): + # The arguments should be on CPU. + for arg in (call.args[0], call.args[1]): + self.unify(self.device_for(arg), device_type(cpu(0))) + self.visit(arg) call_dev = device_type(TVMContext(call.attrs.device_type, call.attrs.device_id)) self.unify(self.device_for(call), call_dev) - # The arguments should be one the same device as the call. - self.visit(call.args[0]) - size = call.args[0] - self.visit(call.args[1]) - alignment = call.args[1] - self.unify(self.device_for(size), call_dev) - self.unify(self.device_for(alignment), call_dev) elif call.op == op.op.get("memory.alloc_tensor"): - storage = call.args[0] - shape = call.args[1] - self.visit(call.args[1]) + storage, shape = call.args[0], call.args[1] self.unify(self.device_for(storage), self.device_for(call)) - self.unify(self.device_for(shape), self.device_for(call)) + # The shape for alloc_tensor should be on CPU. + self.unify(self.device_for(shape), device_type(cpu(0))) + self.visit(shape) elif call.op == op.op.get("vm.shape_func"): shape_func_domain = device_type(cpu(0)) # No need to union the op of a shape_func as shape_func doesn't @@ -359,18 +356,21 @@ def visit_call(self, call): for arg in call.args[2]: self.visit(arg) elif call.op == op.op.get("vm.invoke_tvm_op"): - if isinstance(call.args[0].body, _expr.Call) and \ - call.args[0].body.op == op.op.get("device_copy"): - input_tensor = call.args[1] - output_tensor = call.args[2] - self.device_copy(input_tensor, output_tensor, - call.attrs.src_dev_type, - call.attrs.dst_dev_type) - else: - device = self.unify_call(call.args[0], call.args[1].fields, - call.args[2].fields) - self.unify(self.device_for(call), device) - super().visit_call(call) + device = self.unify_call(call.args[0], call.args[1].fields, + call.args[2].fields) + self.unify(self.device_for(call), device) + super().visit_call(call) + elif call.op == op.op.get("vm.shape_of"): + # vm shape_of is always on the CPU. + self.visit(call.args[0]) + self.unify(self.device_for(call), device_type(cpu(0))) + elif call.op == op.op.get("vm.reshape_tensor"): + data, shape = call.args + self.unify(self.device_for(call), self.device_for(data)) + # The shape field of reshape_tensor is always on the CPU. + self.unify(self.device_for(shape), device_type(cpu(0))) + self.visit(data) + self.visit(shape) elif isinstance(call.op, Function): device = self.device_for(call) for arg in call.args: diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index e851bb93794c..94efaf813451 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -99,6 +99,7 @@ def __init__(self, target_host, context_analysis): self.default_context = cpu(0) self.compute_dtype = "int64" self.context_analysis = context_analysis + self.cached_var = {} super().__init__() def get_context(self, expr): @@ -172,6 +173,7 @@ def visit_let(self, let): self.scopes.append(scope) while isinstance(let, expr.Let): + self.cached_var[let.var] = let.value new_val = self.visit(let.value) scope.let(let.var, new_val) let = let.body @@ -182,6 +184,13 @@ def visit_let(self, let): return scope.get() + def skip_copy_input(self, var): + """Check if device copy for the input should be skipped. We currently + skip copying it when the input is a constant. + """ + return (var in self.cached_var and isinstance(self.cached_var[var], + expr.Constant)) + def emit_shape_func(self, scope, func, new_args): """Insert the shape function given a primitive function.""" shape_func_ins = [] @@ -198,10 +207,20 @@ def emit_shape_func(self, scope, func, new_args): # Pass Shapes if state == 2: for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)): - if ctx.device_type != cpu_ctx.device_type: + # Note vm.shape_of is always executed on CPU. The input may + # need to have a copy as it is likely passed from other + # callers and the argument was on other devices. Constants + # can leave on CPU. + # + # However, this may cause an unnecessary copy when + # all dynamic inputs can be assigned to CPU. An additional + # pass may be needed to decide/remove this copy. + if not self.skip_copy_input(subexp) and \ + ctx.device_type != cpu_ctx.device_type: subexp = self.device_copy(scope, subexp, ctx, cpu_ctx, j) - let_in_arg = scope.let("in_arg_{0}".format(input_pos + j), subexp) - sh_of = self.visit(self.shape_of(let_in_arg)) + subexp = scope.let("in_arg_{0}".format(input_pos + j), + subexp) + sh_of = self.visit(self.shape_of(subexp)) shape_func_ins.append( scope.let("in_shape_{0}".format(input_pos + j), sh_of)) input_pos += 1 @@ -241,22 +260,16 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): out_shapes = self.emit_shape_func(scope, func, new_args) storages = [] - cpu_ctx = nd.cpu(0) func_ctx = self.get_context(func) - copy_out_shapes = [] for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)): size = self.compute_storage_in_relay(out_shape, out_type.dtype) alignment = self.compute_alignment(out_type.dtype) - if func_ctx.device_type != cpu_ctx.device_type: - size = self.device_copy(scope, size, cpu_ctx, func_ctx, i) - out_shape = self.device_copy(scope, out_shape, cpu_ctx, func_ctx, i) - copy_out_shapes.append(out_shape) sto = scope.let("storage_{i}".format(i=i), alloc_storage( size, alignment, func_ctx, out_type.dtype)) storages.append(sto) outs = [] - sh_ty_storage = zip(copy_out_shapes, out_types, storages) + sh_ty_storage = zip(out_shapes, out_types, storages) for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage): alloc = alloc_tensor( storage, @@ -276,11 +289,6 @@ def emit_reshape_tensor(self, scope, func, new_args, ret_type): out_shapes = self.emit_shape_func(scope, func, new_args) shape_expr = out_shapes[0] inp = new_args[0] - inp_ctx = self.get_context(func) - cpu_ctx = nd.cpu(0) - if inp_ctx.device_type != cpu_ctx.device_type: - shape_expr = self.device_copy(scope, shape_expr, cpu_ctx, - inp_ctx, 0) ret = self.reshape_tensor(inp, shape_expr, ret_type.shape) return ret else: From 06b3ced9fae3e1d0db53f86c1f398d94c07ccb63 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 19 Aug 2020 05:35:20 +0000 Subject: [PATCH 06/21] add serialization --- src/runtime/vm/executable.cc | 21 +++++++++++++++++++-- src/runtime/vm/serialize_util.h | 12 +++++++++--- src/runtime/vm/vm.cc | 2 ++ tests/python/relay/test_vm_serialization.py | 15 +++++++++++++++ 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index ef2091746795..c95b739a1d44 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -32,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -249,6 +251,13 @@ void Executable::SaveConstantSection(dmlc::Stream* strm) { for (const auto& it : arrays) { runtime::SaveDLTensor(strm, it); } + + // Save the const to device mapping. + std::vector const_device_type; + for (auto dev_type : this->const_device_type) { + const_device_type.push_back(static_cast(dev_type)); + } + strm->Write(const_device_type); } void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { @@ -448,7 +457,7 @@ void Executable::SaveCodeSection(dmlc::Stream* strm) { for (const auto& func : this->functions) { // Save the function info. VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), - func.params); + func.params, func.params_device_type); func_format.Save(strm); // Serialize each instruction. @@ -515,6 +524,14 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) { STREAM_CHECK(constant.Load(strm), "constant"); this->constants.push_back(constant); } + + // Load the const to device mapping. + std::vector const_device_type; + STREAM_CHECK(strm->Read(&const_device_type), "constant"); + CHECK_EQ(size, const_device_type.size()); + for (auto dev : const_device_type) { + this->const_device_type.push_back(static_cast(dev)); + } } void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { @@ -746,7 +763,7 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { // Create the VM function. VMFunction vm_func = VMFunction(loaded_func.name, loaded_func.params, instructions, - loaded_func.register_file_size); + loaded_func.register_file_size, loaded_func.params_device_type); auto it = this->global_map.find(loaded_func.name); CHECK(it != this->global_map.end()); CHECK_LE(it->second, this->global_map.size()); diff --git a/src/runtime/vm/serialize_util.h b/src/runtime/vm/serialize_util.h index d52b73d81a78..25bdf09e0b5a 100644 --- a/src/runtime/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -57,15 +57,18 @@ struct VMFunctionSerializer { size_t num_instructions; /*! \brief The parameters of the VMFunction. */ std::vector params; + /*! \brief The device type of each parameter of the VMFunction. */ + std::vector params_device_type; VMFunctionSerializer() = default; VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, - const std::vector& params) + const std::vector& params, const std::vector& params_device_type) : name(name), register_file_size(register_file_size), num_instructions(num_instructions), - params(params) {} + params(params), + params_device_type(params_device_type) {} /*! * \brief Load the serialized function header. @@ -81,7 +84,9 @@ struct VMFunctionSerializer { register_file_size = std::stoll(func_info[1]); // Get the number of instructions. num_instructions = static_cast(std::stoll(func_info[2])); - return strm->Read(¶ms); + if (!strm->Read(¶ms)) return false; + if (!strm->Read(¶ms_device_type)) return false; + return true; } /*! @@ -95,6 +100,7 @@ struct VMFunctionSerializer { func_info.push_back(std::to_string(num_instructions)); strm->Write(func_info); strm->Write(params); + strm->Write(params_device_type); } }; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index df33f2ddff61..06a706e6ac7b 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -157,6 +157,8 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const auto& param_names = vm_func.params; CHECK_EQ(args.size() - 1, param_names.size()) << "The number of provided parameters doesn't match the number of arguments"; + CHECK_EQ(param_names.size(), vm_func.params_device_type.size()) + << "The number of provided parameters doesn't match the number of assigned devices"; std::vector func_args(param_names.size()); for (int i = 1; i < args.size(); ++i) { TVMContext ctx; diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index df3bbc19cb58..6486b707fcec 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -298,5 +298,20 @@ def test_vm_shape_of(): tvm.testing.assert_allclose(res.flatten(), data.flatten()) +def test_dynamic_bcast(): + dtype = 'float32' + x = relay.var('x', shape=(relay.Any(), 2), dtype=dtype) + y = relay.var('y', shape=(3, 2), dtype=dtype) + mod = tvm.IRModule() + mod['main'] = relay.Function([x, y], relay.add(x, y)) + x_data = np.random.uniform(size=(1, 2)).astype(dtype) + y_data = np.random.uniform(size=(3, 2)).astype(dtype) + res_np = np.add(x_data, y_data) + for target, ctx in testing.ctx_list(): + res = get_serialized_output(mod, *(x_data, y_data), target=target, + ctx=ctx) + tvm.testing.assert_allclose(res.asnumpy(), res_np) + + if __name__ == "__main__": pytest.main([__file__]) From 19fbe6feb22b5c2340b2c4e676a3a96b9275908c Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 21 Aug 2020 18:59:17 +0000 Subject: [PATCH 07/21] add gpu tests for test_adt --- python/tvm/relay/analysis/context_analysis.py | 11 +- python/tvm/relay/transform/memory_alloc.py | 7 +- tests/python/relay/test_adt.py | 16 +- .../relay/test_pass_context_analysis.py | 140 ++++++++---------- 4 files changed, 86 insertions(+), 88 deletions(-) diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py index 7c651e122bc0..2d8cb778e3d5 100644 --- a/python/tvm/relay/analysis/context_analysis.py +++ b/python/tvm/relay/analysis/context_analysis.py @@ -396,8 +396,11 @@ def visit_call(self, call): # Save the the arg to function mapping for closures as it will # be invoked/unified later. if isinstance(arg.checked_type, ty.FuncType): - assert arg in self.closures - self.closures[param] = self.closures[arg] + if arg in self.closures: + self.closures[param] = self.closures[arg] + else: + assert isinstance(arg, _expr.GlobalVar) + self.closures[param] = arg self.unify(self.device_for(arg), self.device_for(param)) device = self.unify(device, self.device_for(call.op)) @@ -474,7 +477,9 @@ def visit_let(self, let): def visit_function(self, f): self.unify(self.device_for(f), self.device_for(f.body)) - super().visit_function(f) + for x in f.params: + self.device_for(x) + self.visit(f.body) def visit_tuple(self, tup): # We only support tuple with the same of device. diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 94efaf813451..c78d57679047 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -37,7 +37,7 @@ from ..analysis.context_analysis import ContextAnalysis, mk_analysis_annotator from ..._ffi.runtime_ctypes import TVMContext -# logging.basicConfig(level=logging.DEBUG) +logging.basicConfig(level=logging.DEBUG) def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): offset = expr.const(0, dtype="int64") @@ -368,7 +368,10 @@ def transform_module(self, mod, _): fallback_ctx = cpu(0) ca = ContextAnalysis(mod, cur_func, TVMContext(fallback_ctx.device_type, 0)) else: - dev, _ = self.targets.items()[0] + if isinstance(self.targets, dict): + dev = list(self.targets.keys())[0] + else: + dev, _ = self.targets.items()[0] ca = ContextAnalysis(mod, cur_func, nd.context(dev.value)) func = mod["main"] diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index ff76e1c64bcb..48f2292140b1 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te from tvm import relay +from tvm.relay import testing from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor from tvm.relay.prelude import Prelude, StaticTensorArrayOps @@ -719,13 +719,15 @@ def test_iterate(): assert count(res) == 12 -def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", - ta_ctx=tvm.cpu(), target="llvm", rtol=1e-5): +def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5): for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=ta_mod, ctx=ta_ctx, target=target) - result = ex.evaluate()(*args) - got = vmobj_to_list(result, dtype) - tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol) + for target, ctx in testing.ctx_list(): + if kind == "debug" and ctx.device_type != tvm.cpu().device_type: + continue + ex = relay.create_executor(kind, mod=ta_mod, ctx=ctx, target=target) + result = ex.evaluate()(*args) + got = vmobj_to_list(result, dtype) + tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol) def test_tensor_expand_dims(): diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py index a133a8cbf458..0788a1bdd7c1 100644 --- a/tests/python/relay/test_pass_context_analysis.py +++ b/tests/python/relay/test_pass_context_analysis.py @@ -20,80 +20,68 @@ import tvm from tvm import relay +from tvm.relay import expr as _expr, transform +from tvm.relay.analysis.context_analysis import ContextAnalysis, mk_analysis_annotator -data0 = relay.var("data0", shape=(1, relay.Any())) -data1 = relay.var("data1", shape=(1, relay.Any())) - -r0 = relay.cast(data0, dtype="int32") -w0 = relay.const(np.ndarray(shape=(30522, 768), dtype="float32")) -r1 = relay.take(w0, r0, axis=0) -r2 = relay.cast(data1, dtype="int32") -w1 = relay.const(np.ndarray(shape=(2, 768), dtype="float32")) -r3 = relay.take(w1, r2, axis=0) -r4 = relay.add(r1, r3) -r5 = relay.transpose(r4, axes=[1, 0, 2]) -r6 = relay.shape_of(r5, dtype="int32") -r7 = relay.take(r6, relay.const(0, dtype="int32")) -r8 = relay.cast(r7, dtype="float32") -r9 = relay.multiply(relay.const(1, dtype="float32"), r8) -r10 = relay.add(relay.const(0, dtype="float32"), r9) -r11 = relay.arange(relay.const(0, dtype="float32"), r10,\ - relay.const(1, dtype="float32"), dtype="float32") -r12 = relay.cast(r11, dtype="int32") -w2 = relay.const(np.ndarray(shape=(512, 768), dtype="float32")) -r13 = relay.take(w2, r12, axis=0) -r14 = relay.expand_dims(r13, axis=1) -r15 = relay.add(r5, r14) -r16 = relay.nn.dropout(r15, rate=0.1) -# r17 = relay.TupleGetItem(r16.astuple(), 0) -w3 = relay.const(np.ndarray(shape=(768,), dtype="float32")) -w4 = relay.const(np.ndarray(shape=(768,), dtype="float32")) -r18 = relay.nn.layer_norm(r16, w3, w4, epsilon=1e-12) -r19 = relay.op.reverse_reshape(r18, newshape=[-1, 0]) -w5 = relay.const(np.ndarray(shape=(768, 768), dtype="float32")) -r20 = relay.reverse_reshape(w5, newshape=[12, -1, 0]) -w6 = relay.const(np.ndarray(shape=(768, 768), dtype="float32")) -r21 = relay.reverse_reshape(w6, newshape=[12, -1, 0]) -w7 = relay.const(np.ndarray(shape=(768, 768), dtype="float32")) -r22 = relay.reverse_reshape(w7, newshape=[12, -1, 0]) -r23 = relay.Tuple([r20, r21, r22]) -r24 = relay.concatenate(r23, axis=-2) -r25 = relay.reverse_reshape(r24, newshape=[-1, 0]) -r26 = relay.nn.dense(r19, r25, units=2304) -w8 = relay.const(np.ndarray(shape=(768,), dtype="float32")) -w9 = relay.const(np.ndarray(shape=(768,), dtype="float32")) -w10 = relay.const(np.ndarray(shape=(768,), dtype="float32")) -r27 = relay.Tuple([w8, w9, w10]) -r28 = relay.concatenate(r27, axis=0) -r29 = relay.nn.bias_add(r26, r28, axis=-1) -r30 = relay.reshape(r29, newshape=[-1, 1, 2304]) -r31 = relay.reshape(r30, newshape=[0, 0, 12, 3, -1]) -r32 = relay.take(r31, relay.const(0, dtype="int64"), axis=3) -r33 = relay.transpose(r32, axes=[1, 2, 0, 3]) -r34 = relay.reverse_reshape(r33, newshape=[-1, 0, 0]) -r35 = relay.shape_of(r34, dtype="int32") -r36 = relay.take(r35, relay.const(2, dtype="int32")) -r37 = relay.cast(r36, dtype="float32") -r38 = relay.sqrt(r37) -r39 = relay.divide(r34, r38) -r40 = relay.take(r31, relay.const(1, dtype="int64"), axis=3) -r41 = relay.transpose(r40, axes=[1, 2, 0, 3]) -r42 = relay.reverse_reshape(r41, newshape=[-1, 0, 0]) -r43 = relay.nn.batch_matmul(r39, r42) -# r44 = relay.nn.softmax(r43) - -func = relay.Function([data0, data1], r43) -mod = tvm.ir.IRModule.from_expr(func) - -params = {} -exe = relay.vm.compile(mod, target="cuda", params=params) -rt = tvm.runtime.vm.VirtualMachine(exe, tvm.gpu(0)) - -seq_length = 128 -d0 = np.random.randint(0, 1000, size=(1, seq_length)).astype('float32') -d1 = np.ones((1, seq_length)).astype('float32') -d2 = np.asarray([seq_length]).astype('float32') - -rt.set_input("main", data0=d0, data1=d1) - -rt.run() + +def test_device_copy(): + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + x = relay.var("x", shape=(2, 3)) + copy = relay.op.device_copy(x, tvm.cpu(), tvm.gpu()) + out = copy + relay.const(np.random.rand(2, 3)) + glb_var = relay.GlobalVar("main") + mod[glb_var] = relay.Function([x], out) + ca = ContextAnalysis(mod, glb_var, tvm.cpu()) + ca.visit(mod[glb_var]) + ca_res = ca.results() + + for expr, dev in ca_res.items(): + if isinstance(expr, _expr.Call): + assert dev.device_type == tvm.gpu().device_type + elif isinstance(expr, _expr.Var): + assert dev.device_type == tvm.cpu().device_type + elif isinstance(expr, _expr.Constant): + assert dev.device_type == tvm.gpu().device_type + + +def test_shape_func(): + pass + + +def test_vm_shape_of(): + pass + + +def test_alloc_storage(): + pass + + +def test_alloc_tensor(): + pass + + +def test_dynamic_input(): + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + return + mod = tvm.IRModule() + dtype = "float32" + data_shape = (relay.Any(), 4) + tensor_type = relay.TensorType(data_shape, dtype) + tuple_type = relay.TupleType([tensor_type, tensor_type]) + data0 = relay.var("d0", type_annotation=relay.TupleType([tuple_type, tensor_type])) + data1 = relay.var("d1", shape=(relay.Any(), 4), dtype=dtype) + data_tuple = relay.expr.TupleWrapper(data0, 2) + nested_data_tuple = relay.expr.TupleWrapper(data_tuple[0], 2) + y = nested_data_tuple[1] * data_tuple[1] + data1 + mod["main"] = relay.Function([data0, data1], y) + compiler = relay.vm.VMCompiler() + entry = mod.get_global_var("main") + ca = ContextAnalysis(mod, entry, tvm.cpu()) + ca.visit(mod[entry]) + ca_res = ca.results() + +if __name__ == "__main__": + pass From 06fa26ff024a5401439c0c1ae50bb987201fd961 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 22 Aug 2020 05:31:20 +0000 Subject: [PATCH 08/21] cache visited functions --- python/tvm/relay/analysis/context_analysis.py | 11 +++++++++-- python/tvm/relay/transform/memory_alloc.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py index 2d8cb778e3d5..401de8374918 100644 --- a/python/tvm/relay/analysis/context_analysis.py +++ b/python/tvm/relay/analysis/context_analysis.py @@ -165,6 +165,7 @@ def __init__(self, mod, current_func, fallback_device): self.closures = {} self.current_func = current_func self.fallback_device = fallback_device + self.visited = {} def lookup(self, device): """Find the root domain of a given device domain. @@ -383,7 +384,7 @@ def visit_call(self, call): self.unify(device, self.device_for(call.op)) self.unify(device, self.device_for(call.op.body)) - self.visit(call.op) + # self.visit(call.op) elif isinstance(call.op, _expr.GlobalVar): device = self.device_for(call) assert self.mod, "Cannot analyze context on a globalvar without module" @@ -412,6 +413,7 @@ def visit_call(self, call): self.current_func = call.op if cur_func.name_hint != call.op.name_hint: self.visit(func) + self.visited[func] = device self.current_func = cur_func elif isinstance(call.op, _expr.Var): # It is a closure when we call a var @@ -445,6 +447,7 @@ def visit_call(self, call): self.current_func = glb_var if not tvm.ir.structural_equal(cur_func, glb_var): self.visit(func) + self.visited[func] = device self.current_func = cur_func else: self.unify_call(call, call.args, [call]) @@ -476,10 +479,14 @@ def visit_let(self, let): self.visit(let) def visit_function(self, f): - self.unify(self.device_for(f), self.device_for(f.body)) + if f in self.visited and self.visited[f].domain: + return + + device = self.unify(self.device_for(f), self.device_for(f.body)) for x in f.params: self.device_for(x) self.visit(f.body) + self.visited[f] = device def visit_tuple(self, tup): # We only support tuple with the same of device. diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index c78d57679047..09907163f9ab 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -37,7 +37,7 @@ from ..analysis.context_analysis import ContextAnalysis, mk_analysis_annotator from ..._ffi.runtime_ctypes import TVMContext -logging.basicConfig(level=logging.DEBUG) +# logging.basicConfig(level=logging.DEBUG) def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): offset = expr.const(0, dtype="int64") From 39217843915c72170613c9f14864e1c5c0cf3c9c Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 24 Aug 2020 02:46:54 +0000 Subject: [PATCH 09/21] path compression --- python/tvm/relay/analysis/context_analysis.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py index 401de8374918..e8516c7d603e 100644 --- a/python/tvm/relay/analysis/context_analysis.py +++ b/python/tvm/relay/analysis/context_analysis.py @@ -180,7 +180,10 @@ def lookup(self, device): ret : DeviceDomain The root domain. """ - while device in self.device_uf: + while device in self.device_uf and device.domain is None: + # Path compression + if self.device_uf[device] in self.device_uf: + self.device_uf[device] = self.device_uf[self.device_uf[device]] device = self.device_uf[device] return device @@ -319,7 +322,7 @@ def visit_call(self, call): outs.append(call.op) body = call.op.body assert isinstance(body, _expr.Call) and is_device_copy(body) - outs.append(call.op.body) + # outs.append(call.op.body) src_dev_type = call.op.body.attrs.src_dev_type dst_dev_type = call.op.body.attrs.dst_dev_type else: @@ -478,8 +481,12 @@ def visit_let(self, let): self.visit(let) + def is_primitive(self, func): + return hasattr(func, 'attrs') and hasattr(func.attrs, 'Primitive') \ + and int(func.attrs.Primitive) == 1 + def visit_function(self, f): - if f in self.visited and self.visited[f].domain: + if self.is_primitive(f) or (f in self.visited and self.visited[f].domain): return device = self.unify(self.device_for(f), self.device_for(f.body)) From bdd2d81085be78eb0a2d66775e960afd993ddd57 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 24 Aug 2020 23:01:44 +0000 Subject: [PATCH 10/21] C++ context analysis --- python/tvm/relay/analysis/analysis.py | 14 + python/tvm/relay/analysis/context_analysis.py | 2 +- python/tvm/relay/transform/memory_alloc.py | 30 +- src/relay/analysis/context_analysis.cc | 686 ++++++++++++++++++ 4 files changed, 717 insertions(+), 15 deletions(-) create mode 100644 src/relay/analysis/context_analysis.cc diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 99f4252ac4f7..f15413ce9b5d 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -28,6 +28,20 @@ from .feature import Feature +def context_analysis(mod, default_context): + """Analyze the context information of a Relay program. + + Parameters + ---------- + expr : tvm.IRModule + The input module. + + default_context : tvm.runtime.TVMContext + The default context allocated to an IR node. + """ + return _ffi_api.ContextAnalysis(mod, default_context) + + def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. Each node is guaranteed to be visited diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py index e8516c7d603e..77565104f60a 100644 --- a/python/tvm/relay/analysis/context_analysis.py +++ b/python/tvm/relay/analysis/context_analysis.py @@ -590,4 +590,4 @@ def context_analysis(mod, fallback_device): return ret -register_func("relay.analysis.ContextAnalysis", context_analysis) +# register_func("relay.analysis.ContextAnalysis", context_analysis) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 09907163f9ab..cc610ce25ce2 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -22,11 +22,10 @@ import logging from tvm.ir.transform import PassContext, module_pass -from tvm import nd, container, tir +from tvm import nd, container from ..function import Function from ..expr_functor import ExprVisitor, ExprMutator from ..scope_builder import ScopeBuilder -from . import transform from .. import op from ... import DataType, register_func from .. import ty, expr @@ -34,7 +33,8 @@ from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type from ...import cpu from ..op.memory import alloc_storage -from ..analysis.context_analysis import ContextAnalysis, mk_analysis_annotator +from ..analysis import context_analysis +from ..analysis.context_analysis import mk_analysis_annotator from ..._ffi.runtime_ctypes import TVMContext # logging.basicConfig(level=logging.DEBUG) @@ -358,31 +358,33 @@ def transform_module(self, mod, _): # can we have def pass_init? mod.import_from_std("core.rly") + # We use logger here to help debug. + logging.debug("-----BEFORE CONTEXT ANALYSIS-----") + logging.debug(mod.astext(False)) + assert isinstance(self.targets, (dict, container.Map)) - cur_func = mod.get_global_var("main") if len(self.targets) > 1: pass_ctx = PassContext.current() if "relay.fallback_device_type" in pass_ctx.config: fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"]) else: fallback_ctx = cpu(0) - ca = ContextAnalysis(mod, cur_func, TVMContext(fallback_ctx.device_type, 0)) + ca = context_analysis(mod, TVMContext(fallback_ctx.device_type, 0)) else: if isinstance(self.targets, dict): dev = list(self.targets.keys())[0] else: dev, _ = self.targets.items()[0] - ca = ContextAnalysis(mod, cur_func, nd.context(dev.value)) + ca = context_analysis(mod, nd.context(dev.value)) - func = mod["main"] - # We use logger here to help debug. - logging.debug("-----BEFORE ANALYSIS-----") - logging.debug(mod.astext(False)) - ca.visit(func) - logging.debug("-----AFTER ANALYSIS-----") + # TODO(zhiics) This is not needed after we port the pass to C++ + ca_res = {} + for key, val in ca.items(): + ca_res[key] = TVMContext(val[0].value, val[1].value) + + logging.debug("-----AFTER CONTEXT ANALYSIS-----") logging.debug(mod.astext(show_meta_data=False, - annotate=mk_analysis_annotator(ca.results()))) - ca_res = ca.results() + annotate=mk_analysis_annotator(ca_res))) gv_funcs = mod.functions for gv, f in gv_funcs.items(): ea = ManifestAllocPass(self.target_host, ca_res) diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc new file mode 100644 index 000000000000..5c664bb47db9 --- /dev/null +++ b/src/relay/analysis/context_analysis.cc @@ -0,0 +1,686 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/context_analysis.cc + * \brief A pass for analyzing device attribute of each IR node. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace analysis { + +using PackedAnalysisResultMap = Map>; +using AnalysisResultMap = + std::unordered_map; + +// Cache ops +static const Op& device_copy_op = Op::Get("device_copy"); +static const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); +static const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); +static const Op& shape_of_op = Op::Get("vm.shape_of"); +static const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); +static const Op& shape_func_of = Op::Get("vm.shape_func"); +static const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); + +class DeviceDomain; +using DeviceDomainPtr = std::shared_ptr; + +/* + * \brief A class to represent the device of a domain, i.e. a segment of relay program. + */ +class DeviceDomain { + public: + // Construct an empty domain. + DeviceDomain() { + ctx_.device_type = static_cast(-1); + ctx_.device_id = -1; + } + + // Construct a domain based on a given context. + explicit DeviceDomain(const TVMContext& ctx) : ctx_(ctx) {} + + // Check if the current domain is empty. + bool IsEmptyDomain() const { + return static_cast(ctx_.device_type) == -1 && ctx_.device_id == -1; + } + + // Check if the current domain equals the other one. + bool operator==(const DeviceDomain& other) const { + return ctx_.device_type == other.ctx_.device_type && ctx_.device_id == other.ctx_.device_id; + } + + bool operator!=(const DeviceDomain& other) const { return !(*this == other); } + + private: + // Create a hash for a domain. + struct Hash { + size_t operator()(const DeviceDomainPtr& domain) const { + if (domain->IsEmptyDomain()) { + return (size_t)(domain.get()); + } else { + size_t const h1(std::hash()(static_cast(domain->ctx_.device_type))); + size_t const h2(std::hash()(domain->ctx_.device_id)); + return h1 ^ (h2 << 1); + } + } + }; + + // Create an equality for domains. + struct Equal { + public: + bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { + // We compare the pointer for empty domains. + if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) return lhs.get() == rhs.get(); + + // Otherwise device type and id are used to check equality. + return (*lhs.get() == *rhs.get()); + } + }; + + /* \brief The device to be assigned to the current domain. */ + TVMContext ctx_; + + friend DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + friend class ContextAnalyzer; +}; + +// Join two domains. +DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) { + return lhs; + } else if (lhs->IsEmptyDomain()) { + return rhs; + } else if (rhs->IsEmptyDomain()) { + return lhs; + } else { + CHECK(*lhs.get() == *rhs.get()) << "All expressions must have a singular device to unify"; + return lhs; + } +} + +/* + * \brief Compute on which device each sub-expression will execute. A union find + * algorithm is used to assign and merge the context domains. + */ +class ContextAnalyzer : public ExprVisitor { + public: + ContextAnalyzer(const IRModule& mod, const GlobalVar& current_func, + const TVMContext& default_context) + : mod_(mod), current_func_(current_func), default_context_(default_context) { + cpu_ctx_.device_type = kDLCPU; + cpu_ctx_.device_id = 0; + } + + // Create an empty domain. + // This usually happens when we enter a new scope, i.e. Function. + DeviceDomainPtr Bottom() { return std::make_shared(DeviceDomain()); } + + // Create a domain with the given device context. + DeviceDomainPtr DeviceType(const TVMContext& ctx) { + return std::make_shared(DeviceDomain(ctx)); + } + + // Find the root of a device. + DeviceDomainPtr Lookup(DeviceDomainPtr device) { + while (device_uf_.count(device) && device != device_uf_[device]) { + // Path compression + if (device_uf_.count(device_uf_[device])) { + device_uf_[device] == device_uf_[device_uf_[device]]; + } + device = device_uf_[device]; + } + return device; + } + + // Unify two domains. + DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + lhs = Lookup(lhs); + rhs = Lookup(rhs); + auto unified_device = Join(lhs, rhs); + if (lhs != unified_device) { + device_uf_[lhs] = unified_device; + } + + if (rhs != unified_device) { + device_uf_[rhs] = unified_device; + } + + return unified_device; + } + + // Unify the domain for two IR nodes. + DeviceDomainPtr UnifyExpr(const Expr& lhs, const Expr& rhs) { + auto lhs_dom = DeviceFor(lhs); + auto rhs_dom = DeviceFor(rhs); + return Unify(lhs_dom, rhs_dom); + } + + // Lookup or insert an IR node to device domain map. + DeviceDomainPtr DeviceFor(const Expr& expr) { + auto it = expr_to_device_.find(expr); + if (it == expr_to_device_.end()) { + auto bottom = Bottom(); + expr_to_device_[expr] = bottom; + return bottom; + } else { + return it->second; + } + } + + // Unify the device context for a device copy node. Device copy node is + // the only node that carries bidirectional devices in the input program. The device + // attribute of other nodes can be propagated from it. + void UnifyDeviceCopy(const std::vector& inps, const std::vector& outputs, + DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { + TVMContext src_ctx; + src_ctx.device_type = src_dev_type; + src_ctx.device_id = 0; + auto src_domain = DeviceType(src_ctx); + for (const auto& it : inps) { + auto lhs = DeviceFor(it); + Unify(lhs, src_domain); + } + + TVMContext dst_ctx; + dst_ctx.device_type = dst_dev_type; + dst_ctx.device_id = 0; + auto dst_domain = DeviceType(dst_ctx); + for (const auto& it : outputs) { + auto lhs = DeviceFor(it); + Unify(lhs, dst_domain); + } + } + + // Unify the domain of inputs and outputs of a relay call. + // + // For most call nodes, the op, inputs, and outputs should all be in the + // same domain, i.e. having the same context. However, device_copy call node + // needs to be handled differently as it copies data from one device to + // another. + DeviceDomainPtr UnifyCall(const Expr& call_op, const Array& inps, + const Array& outputs, DeviceDomainPtr device) { + device = Unify(device, DeviceFor(call_op)); + + for (const auto& it : inps) { + device = Unify(device, DeviceFor(it)); + } + + for (const auto& it : outputs) { + device = Unify(device, DeviceFor(it)); + } + + return device; + } + + void VisitExpr_(const CallNode* cn) final { + Call call = GetRef(cn); + + if (IsDeviceCopy(call)) { + UnifyDeviceCopyCall(cn); + } else if (call->op == alloc_storage_op) { + UnifyAllocStorageCall(cn); + } else if (call->op == alloc_tensor_op) { + UnifyAllocTensorCall(cn); + } else if (call->op == shape_func_of) { + UnifyShapeFuncCall(cn); + } else if (call->op == shape_of_op) { + UnifyShapeOfCall(cn); + } else if (call->op == invoke_tvm_op) { + UnifyInvokeTVMOpCall(cn); + } else if (call->op == reshape_tensor_op) { + UnifyReshapeTensorCall(cn); + } else if (call->op.as()) { + UnifyFunctionCall(cn); + } else if (call->op.as()) { + UnifyGlobalVarCall(cn); + } else if (call->op.as()) { + UnifyVarCall(cn); + } else { + UnifyCall(call, cn->args, {call}, Bottom()); + ExprVisitor::VisitExpr_(cn); + } + } + + void VisitExpr_(const LetNode* ln) final { + Expr expr = GetRef(ln); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // Save currying/closures since they will be invoked later + auto ty = let->value->checked_type(); + if (ty->IsInstance()) { + auto gv = ExtractClosure(let); + CHECK(gv.defined() && gv->IsInstance()); + closures_[let->var] = Downcast(gv); + } + + // Unify let var, value, and body + Unify(DeviceFor(let->var), DeviceFor(let->value)); + UnifyExpr(let, let->body); + ExprVisitor::VisitExpr(let->value); + expr = let->body; + } + // Visit the last body + ExprVisitor::VisitExpr(expr); + } + + void VisitExpr_(const FunctionNode* fn) final { + auto func = GetRef(fn); + auto it = visited_.find(func); + // No need to step into fused primitive functions as they are handled as + // a whole. + if (fn->HasNonzeroAttr(attr::kPrimitive) || + (it != visited_.end() && !DeviceFor(func)->IsEmptyDomain())) { + return; + } + + auto device = Unify(DeviceFor(func), DeviceFor(fn->body)); + for (const auto& it : fn->params) { + DeviceFor(it); + } + ExprVisitor::VisitExpr(fn->body); + visited_.insert(func); + } + + void VisitExpr_(const TupleNode* tn) final { + // We only support tuple with the same of device. + Tuple tup = GetRef(tn); + if (tn->fields.size() > 0) { + auto device = DeviceFor(tup->fields[0]); + for (size_t i = 1; i < tup->fields.size(); i++) { + device = Unify(device, DeviceFor(tup->fields[i])); + } + Unify(device, DeviceFor(tup)); + } + ExprVisitor::VisitExpr_(tn); + } + + void VisitExpr_(const TupleGetItemNode* tn) final { + TupleGetItem item = GetRef(tn); + + Unify(DeviceFor(item), DeviceFor(item->tuple)); + + ExprVisitor::VisitExpr_(tn); + } + + void VisitExpr_(const MatchNode* mn) final { + // For match node, we unify the value and the rhs of each clause + Match m = GetRef(mn); + auto device = Unify(DeviceFor(m), DeviceFor(m->data)); + for (const auto& c : m->clauses) { + device = Unify(device, DeviceFor(c->rhs)); + } + ExprVisitor::VisitExpr_(mn); + } + + void VisitExpr_(const GlobalVarNode* gvn) final { DeviceFor(GetRef(gvn)); } + + void VisitExpr_(const VarNode* vn) { DeviceFor(GetRef(vn)); } + + void VisitExpr_(const ConstantNode* cn) final { DeviceFor(GetRef(cn)); } + + // Return the analysis results. + AnalysisResultMap Results() { + AnalysisResultMap ret; + for (const auto& it : expr_to_device_) { + auto device = Lookup(it.second); + if (device->IsEmptyDomain()) { + ret[it.first] = default_context_; + } else { + ret[it.first] = device->ctx_; + } + } + + return ret; + } + + private: + Expr ExtractClosure(Expr expr) const { + while (expr->IsInstance()) { + Let let = Downcast(expr); + expr = let->value; + if (expr->IsInstance()) { + return expr; + } else { + const auto* cn = expr.as(); + if (cn && cn->op->IsInstance()) { + return cn->op; + } + } + } + return Expr(nullptr); + } + + // Check if an expression is a device copy call. + bool IsDeviceCopy(const Expr& expr) const { + if (!expr->IsInstance()) return false; + + Call call = Downcast(expr); + if (call->op == device_copy_op) return true; + + // Fused function with device copy op as the body + // device copy op is opaque therefore the fused function only has one node. + if (const FunctionNode* fn = call->op.as()) { + if (const CallNode* cn = fn->body.as()) { + return cn->op == device_copy_op; + } + } + + return false; + } + + // Check if a function is a closure. + bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } + + // Check if a function is a currying function. + bool IsCurrying(const Function& func) { + if (const auto* let = func->body.as()) { + return closures_.find(let->var) != closures_.end(); + } + return false; + } + + // Process device copy call node + void UnifyDeviceCopyCall(const CallNode* call) { + CHECK_EQ(call->args.size(), 1U); + + std::vector inps{call->args[0]}; + std::vector outs{GetRef(call)}; + DLDeviceType src_dev_type, dst_dev_type; + const DeviceCopyAttrs* attrs = nullptr; + if (const auto* fn = call->op.as()) { + // device_copy is fused, propagate device to the fused function. + inps.push_back(fn->params[0]); + outs.push_back(call->op); + Expr body = fn->body; + // outs.push_back(fn->body); + CHECK(body->IsInstance() && IsDeviceCopy(body)); + Call call_body = Downcast(body); + attrs = call_body->attrs.as(); + } else { + attrs = call->attrs.as(); + } + CHECK(attrs != nullptr); + src_dev_type = static_cast(attrs->src_dev_type); + dst_dev_type = static_cast(attrs->dst_dev_type); + + // Device copy op only has one input which is now annotated with the + // same device to the source device type of the device copy op. + // The call itself has the same device type to the destination. + UnifyDeviceCopy(inps, outs, src_dev_type, dst_dev_type); + ExprVisitor::VisitExpr_(call); + } + + void UnifyAllocStorageCall(const CallNode* call) { + CHECK_EQ(call->args.size(), 2U); + + // The arguments of alloc storage should be on CPU. + for (int i = 0; i < 2; i++) { + Unify(DeviceFor(call->args[i]), DeviceType(cpu_ctx_)); + ExprVisitor::VisitExpr(call->args[i]); + } + TVMContext ctx; + const auto* attrs = call->attrs.as(); + ctx.device_type = static_cast(attrs->device_type); + ctx.device_id = attrs->device_id; + Unify(DeviceFor(GetRef(call)), DeviceType(ctx)); + } + + void UnifyAllocTensorCall(const CallNode* call) { + CHECK_EQ(call->args.size(), 3U); + + Expr storage = call->args[0]; + Expr shape = call->args[1]; + Unify(DeviceFor(storage), DeviceFor(GetRef(call))); + + // The shape for alloc_tensor should be on CPU. + Unify(DeviceFor(shape), DeviceType(cpu_ctx_)); + ExprVisitor::VisitExpr(shape); + } + + void UnifyShapeFuncCall(const CallNode* call) { + CHECK_EQ(call->args.size(), 3U); + auto shape_func_domain = DeviceType(cpu_ctx_); + + // No need to unify the op of a shape_func as shape_func doesn't + // invoke the op itself. It should be handled by invoke_tvm_op. + // Therefore, we skip call.args[0] here. + Tuple inps = Downcast(call->args[1]); + Tuple outputs = Downcast(call->args[2]); + UnifyCall(GetRef(call), inps->fields, outputs->fields, shape_func_domain); + for (const auto& it : inps->fields) { + ExprVisitor::VisitExpr(it); + } + + for (const auto& it : outputs->fields) { + ExprVisitor::VisitExpr(it); + } + } + + void UnifyInvokeTVMOpCall(const CallNode* call) { + CHECK_EQ(call->args.size(), 3U); + Tuple inps = Downcast(call->args[1]); + Tuple outputs = Downcast(call->args[2]); + UnifyCall(call->args[0], inps->fields, outputs->fields, Bottom()); + ExprVisitor::VisitExpr_(call); + } + + void UnifyShapeOfCall(const CallNode* call) { + // vm shape_of is always on the CPU. + CHECK_EQ(call->args.size(), 1U); + ExprVisitor::VisitExpr(call->args[0]); + Unify(DeviceFor(GetRef(call)), DeviceType(cpu_ctx_)); + } + + void UnifyReshapeTensorCall(const CallNode* call) { + CHECK_EQ(call->args.size(), 2U); + Expr data = call->args[0]; + Expr shape = call->args[1]; + Unify(DeviceFor(GetRef(call)), DeviceFor(data)); + + // The shape field of reshape_tensor is always on the CPU. + Unify(DeviceFor(shape), DeviceType(cpu_ctx_)); + ExprVisitor::VisitExpr(data); + ExprVisitor::VisitExpr(shape); + } + + void UnifyFunctionCall(const CallNode* call) { + auto device = DeviceFor(GetRef(call)); + // Unify the arguments of the caller. + for (const auto& arg : call->args) { + device = Unify(device, DeviceFor(arg)); + ExprVisitor::VisitExpr(arg); + } + + // Unify the parameters of the callee. + if (!call->op->IsInstance()) return; + Function func = Downcast(call->op); + for (const auto& param : func->params) { + device = Unify(device, DeviceFor(param)); + ExprVisitor::VisitExpr(param); + } + + // Unify the function expression and its body + Unify(device, DeviceFor(call->op)); + Unify(device, DeviceFor(func->body)); + + // Step into the callee. It will be skipped if the callee if a primitive + // function + ExprVisitor::VisitExpr(call->op); + } + + // Invoke a global function. + void UnifyGlobalVarCall(const CallNode* call) { + auto device = DeviceFor(GetRef(call)); + CHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module"; + GlobalVar gv = Downcast(call->op); + auto func = Downcast(mod_->Lookup(gv)); + CHECK_EQ(call->args.size(), func->params.size()) + << "The number of arguments doesn't match the number of parameters of the function."; + + for (size_t i = 0; i < call->args.size(); i++) { + Expr arg = call->args[i]; + Expr param = func->params[i]; + ExprVisitor::VisitExpr(arg); + + // Save the the arg to function mapping for closures as it will + // be invoked/unified later. + CHECK(arg->checked_type().defined()) + << "Type inference is required to run the context analysis passes."; + if (arg->checked_type()->IsInstance()) { + auto it = closures_.find(arg); + if (it != closures_.end()) { + closures_[param] = it->second; + } else { + CHECK(arg->IsInstance()); + closures_[param] = Downcast(arg); + } + } + Unify(DeviceFor(arg), DeviceFor(param)); + } + device = Unify(device, DeviceFor(call->op)); + device = Unify(device, DeviceFor(func)); + device = Unify(device, DeviceFor(func->body)); + + // Step into the callee. We need to skip recursive calls, otherwise, it + // would be a infinite loop. + // + // TODO(@zhiics) This may cause problem for mutual recursive calls as well. + auto cur_func = current_func_; + current_func_ = gv; + if (cur_func->name_hint != gv->name_hint) { + ExprVisitor::VisitExpr(func); + visited_.insert(func); + } + // Exit the frame. + current_func_ = cur_func; + } + + void UnifyVarCall(const CallNode* call) { + // It is a closure when we call a var. + // Unify the corresponding arguement and parameter. + auto device = DeviceFor(GetRef(call)); + auto it = closures_.find(call->op); + CHECK(it != closures_.end()) << "Cannot find var: " << call->op; + auto glb_var = it->second; + CHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module"; + Function func = Downcast(mod_->Lookup(glb_var)); + // Unify the underlying function for clousre or currying funcitons. + while (IsClosure(func) || IsCurrying(func)) { + device = Unify(device, DeviceFor(func)); + if (IsClosure(func)) { + func = Downcast(func->body); + } else if (IsCurrying(func)) { + Let let = Downcast(func->body); + func = Downcast(mod_->Lookup(closures_[let->var])); + } else { + LOG(FATAL) << "func is expected to be a closure or a currying funciton"; + } + } + + CHECK_EQ(call->args.size(), func->params.size()); + for (size_t i = 0; i < call->args.size(); i++) { + Unify(DeviceFor(call->args[i]), DeviceFor(func->params[i])); + ExprVisitor::VisitExpr(call->args[i]); + } + device = Unify(device, DeviceFor(call->op)); + device = Unify(device, DeviceFor(glb_var)); + device = Unify(device, DeviceFor(func)); + + // Step into the global function. + auto cur_func = current_func_; + current_func_ = glb_var; + if (cur_func->name_hint != glb_var->name_hint) { + ExprVisitor::VisitExpr(func); + visited_.insert(func); + } + current_func_ = cur_func; + } + + private: + /* \brief The cpu context. */ + TVMContext cpu_ctx_; + /* \brief The module that helps context analysis. */ + const IRModule& mod_; + /* \brief The current function that is being analyzed. */ + GlobalVar current_func_; + /* \brief The default device that could be attached to an expression. */ + const TVMContext& default_context_; + /* \brief The IR node to device domain mapping. */ + std::unordered_map + expr_to_device_; + /* \brief The domain map for union-find. */ + std::unordered_map + device_uf_; + /* + * \brief The expr to global var map. It saves the closures/currying that + * will be invoked lazily. + */ + std::unordered_map closures_; + /* \brief Cache the visited functions. */ + std::unordered_set visited_; +}; + +AnalysisResultMap ContextAnalysis(const IRModule& mod, const TVMContext& default_context) { + // TODO(@zhiics) Apply the pass to all functions/entries + auto entry = mod->GetGlobalVar("main"); + auto ca = ContextAnalyzer(mod, entry, default_context); + auto expr = mod->Lookup(entry); + ca.VisitExpr(expr); + return ca.Results(); +} + +// Unpack the device type and deivce fields in TVMContext for PackedFunc calls +// as TVMContext is not in the object system. +PackedAnalysisResultMap ContextAnalysisPacked(const IRModule& mod, + const TVMContext& default_context) { + PackedAnalysisResultMap ret; + auto res = ContextAnalysis(mod, default_context); + for (const auto& it : res) { + Integer dev_ty = static_cast(it.second.device_type); + Integer dev_id = it.second.device_id; + ret.Set(it.first, {dev_ty, dev_id}); + } + + return ret; +} + +} // namespace analysis + +TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis") + .set_body_typed([](IRModule mod, TVMContext default_context) { + return analysis::ContextAnalysisPacked(mod, default_context); + }); + +} // namespace relay +} // namespace tvm From 68cde3629a2befe84e86f6a1a9cf904ca9d37980 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 25 Aug 2020 16:10:04 +0000 Subject: [PATCH 11/21] remove python context analysis --- include/tvm/relay/analysis.h | 11 + python/tvm/relay/analysis/analysis.py | 3 +- python/tvm/relay/analysis/context_analysis.py | 593 ------------------ python/tvm/relay/transform/memory_alloc.py | 69 +- src/relay/analysis/context_analysis.cc | 13 +- src/relay/backend/vm/compiler.cc | 21 +- src/relay/backend/vm/compiler.h | 2 +- src/runtime/vm/serialize_util.h | 3 +- .../relay/test_pass_context_analysis.py | 10 +- 9 files changed, 68 insertions(+), 657 deletions(-) delete mode 100644 python/tvm/relay/analysis/context_analysis.py diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index c65bb41282cf..00da9408408b 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -263,6 +263,17 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod); */ TVM_DLL Map> GetCalibrateOutputMap(const IRModule& mod); +/*! + * \brief Analyze the device context of each IR node in a given relay module. + * + * \param mod The module for analysis. + * \param default_context The default context used by unassigned IR nodes. + * + * \return The mapping between an IR node and its associated context. + */ +TVM_DLL std::unordered_map +ContextAnalysis(const IRModule& mod, const TVMContext& default_context); + } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index f15413ce9b5d..a6662e6c5eb3 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -29,7 +29,8 @@ def context_analysis(mod, default_context): - """Analyze the context information of a Relay program. + """Analyze the device context information of each IR node in a Relay + program. Parameters ---------- diff --git a/python/tvm/relay/analysis/context_analysis.py b/python/tvm/relay/analysis/context_analysis.py deleted file mode 100644 index 77565104f60a..000000000000 --- a/python/tvm/relay/analysis/context_analysis.py +++ /dev/null @@ -1,593 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks -""" -A pass for analyzing device attribute of each IR node. -""" -from typing import Optional -from collections import defaultdict - -import tvm -from ..expr_functor import ExprVisitor -from ..function import Function -from .. import ty, op, expr as _expr -from ... import register_func, cpu -from ..._ffi.runtime_ctypes import TVMContext - -def is_closure(func): - """Check if a function is a closure. - - Parameters - ---------- - func : tvm.relay.Function - The input function. - - Returns - ------- - True if the input function is a closure, otherwise false. - """ - return hasattr(func, 'attrs') and \ - hasattr(func.attrs, 'Closure') and int(func.attrs.Closure) == 1 - - -def is_device_copy(call): - """Check if a call node is a device copy call. - - Parameters - ---------- - call : tvm.relay.Call - The call node to be checked. - - Returns - ------- - ret : Boolean - True if the call is a device copy call. Otherwise, false. - """ - if not isinstance(call, _expr.Call): - return False - if call.op == op.op.get("device_copy"): - return True - - if not isinstance(call.op, Function): - return False - return isinstance(call.op.body, _expr.Call) and \ - call.op.body.op == op.op.get("device_copy") - - -class DeviceDomain: - """A class to represent the device of a domain, i.e. a segment of relay - program. - - Parameters - ---------- - ctx : Optional[tvm.runtime.TVMContext] - The device to be assigned to the current domain. It is optional. - """ - def __init__(self, ctx: Optional[TVMContext]): - self.domain = ctx - - def join(self, other: 'DeviceDomain') -> 'DeviceDomain': - """Merge the device of two domains. - - Parameters - ---------- - other : DeviceDomain - The other domain to be merged. - - Returns - ------- - ret : DeviceDomain - The merged domain. An error will be raised if two domain has - conflict, i.e. they have different context. - """ - if self.domain is None and other.domain is None: - return self - elif self.domain is None: - return other - elif other.domain is None: - return self - elif (self.domain.device_type == other.domain.device_type and - self.domain.device_id == other.domain.device_id): - return self - else: - raise Exception("all expressions must have a singular device") - - def __hash__(self): - if self.domain is None: - return id(self) - else: - return hash((self.domain.device_type, self.domain.device_id)) - - def __eq__(self, other): - if self.domain is None and other.domain is None: - return id(self) == id(other) - else: - return self.domain == other.domain - - -def bottom(): - """Create an empty domain. This would usually happen when we enter a new - scope, i.e. Function. - """ - return DeviceDomain(None) - - -def device_type(ctx): - """Create a domain with the given device context. - - Parameters - ---------- - ctx : tvm.runtime.TVMContext - The device context used to construct a domain. - - Returns - ------- - ret : DeviceDomain - The constructed domain. - """ - return DeviceDomain(ctx) - - -class ContextAnalysis(ExprVisitor): - """Compute on which device each sub-expression will execute. A union find - algorithm is used to assign and merge the context domains. - - Parameters - ---------- - mod : tvm.IRModule - The module that helps context analysis. - - current_func : tvm.relay.GlobalVar - The current function that is being analyzed. - - fallback_device : tvm.runtime.TVMContext - The default device that could be attached to an expression. - """ - def __init__(self, mod, current_func, fallback_device): - super().__init__() - self.expr_to_device = defaultdict(bottom) - self.device_uf = {} - self.mod = mod - self.closures = {} - self.current_func = current_func - self.fallback_device = fallback_device - self.visited = {} - - def lookup(self, device): - """Find the root domain of a given device domain. - - Parameters - ---------- - device : DeviceDomain - The domain that is used to query the root domain. - - Returns - ------- - ret : DeviceDomain - The root domain. - """ - while device in self.device_uf and device.domain is None: - # Path compression - if self.device_uf[device] in self.device_uf: - self.device_uf[device] = self.device_uf[self.device_uf[device]] - device = self.device_uf[device] - return device - - def unify(self, lhs, rhs): - """Unify the device context of two domains. - - Parameters - ---------- - lhs : DeviceDomain - The lhs domain to unify. - - rhs : DeviceDomain - The rhs domain to unify. - - Returns - ------- - ret : DeviceDomain - The unified domain. - """ - lhs = self.lookup(lhs) - rhs = self.lookup(rhs) - unified_device = lhs.join(rhs) - if not lhs == unified_device: - self.device_uf[lhs] = unified_device - if not rhs == unified_device: - self.device_uf[rhs] = unified_device - return unified_device - - def unify_expr(self, lhs, rhs): - """Compute the device type of both expressions and unify them. - - Parameters - ---------- - lhs : tvm.relay.Expr - The lhs expression to unify. - - rhs : tvm.relay.Expr - The rhs expression to unify. - - Returns - ------- - ret : DeviceDomain - The unified domain. - """ - return self.unify(self.device_for(lhs), self.device_for(rhs)) - - def device_for(self, expr): - """Find the domain that contains the given expr. - - Parameters - ---------- - expr : tvm.relay.Expr - The expression used to lookup a domain. - - Returns - ------- - ret : DeviceDomain - The domain that contains the expression. - """ - return self.lookup(self.expr_to_device[expr]) - - def device_copy(self, inps, outputs, src_dev_type, dst_dev_type): - """Unify the device context for device copy node. Device copy node is - the only node that carries information in the input program. The device - attribute of other nodes are propagated from it. - - Parameters - ---------- - inps : List[tvm.relay.Expr] - The input expression to the device copy node. The device type of - the input should be the same as the source device type of the - copy node. - - outputs : List[tvm.relay.Expr] - The output expression of the device copy node. The device type of - the output should be the same as the destination device type of the - copy node. - - src_dev_type : int - The source device type of the copy node. - - dst_dev_type : int - The destination device type of the copy node. - """ - src_dev_type = device_type(TVMContext(src_dev_type, 0)) - for inp in inps: - self.unify(self.device_for(inp), src_dev_type) - - dst_dev_type = device_type(TVMContext(dst_dev_type, 0)) - for output in outputs: - self.unify(self.device_for(output), dst_dev_type) - - def unify_call(self, call_op, inputs, outputs, device=None): - """Unify the domain of inputs and outputs of a relay Call. - - Parameters - ---------- - op : tvm.relay.Expr - The op of a call node. - - inputs : List[tvm.relay.Expr] - The inputs of the call. - - outputs : List[tvm.relay.Expr] - The outputs of the call. - - Returns - ------- - The unified domain. - - Note - ---- - For most call nodes, the op, inputs, and outputs should all be in the - same domain, i.e. having the same context. However, device_copy call node - needs to be handled differently as it copies data from one device to - another. - """ - device = device if device else bottom() - for arg in inputs: - device = self.unify(device, self.device_for(arg)) - - device = self.unify(device, self.device_for(call_op)) - - for out in outputs: - device = self.unify(device, self.device_for(out)) - - return device - - def visit_call(self, call): - if is_device_copy(call): - inps = [call.args[0]] - outs = [call] - if isinstance(call.op, Function): - # device_copy is fused, propagate device to the fused function - inps.append(call.op.params[0]) - outs.append(call.op) - body = call.op.body - assert isinstance(body, _expr.Call) and is_device_copy(body) - # outs.append(call.op.body) - src_dev_type = call.op.body.attrs.src_dev_type - dst_dev_type = call.op.body.attrs.dst_dev_type - else: - src_dev_type = call.attrs.src_dev_type - dst_dev_type = call.attrs.dst_dev_type - - # Device copy op only has one input which is now annotated with the - # same device to the source device type of the device copy op. - # The call itself has the same device type to the destination. - self.device_copy(inps, outs, src_dev_type, dst_dev_type) - super().visit_call(call) - elif call.op == op.op.get("memory.alloc_storage"): - # The arguments should be on CPU. - for arg in (call.args[0], call.args[1]): - self.unify(self.device_for(arg), device_type(cpu(0))) - self.visit(arg) - call_dev = device_type(TVMContext(call.attrs.device_type, - call.attrs.device_id)) - self.unify(self.device_for(call), call_dev) - elif call.op == op.op.get("memory.alloc_tensor"): - storage, shape = call.args[0], call.args[1] - self.unify(self.device_for(storage), self.device_for(call)) - # The shape for alloc_tensor should be on CPU. - self.unify(self.device_for(shape), device_type(cpu(0))) - self.visit(shape) - elif call.op == op.op.get("vm.shape_func"): - shape_func_domain = device_type(cpu(0)) - # No need to union the op of a shape_func as shape_func doesn't - # invoke the op itself. It should be handled by invoke_tvm_op. - # Therefore, we skip call.args[0] here. - self.unify_call(call, call.args[1].fields, - call.args[2].fields, shape_func_domain) - for arg in call.args[1]: - self.visit(arg) - for arg in call.args[2]: - self.visit(arg) - elif call.op == op.op.get("vm.invoke_tvm_op"): - device = self.unify_call(call.args[0], call.args[1].fields, - call.args[2].fields) - self.unify(self.device_for(call), device) - super().visit_call(call) - elif call.op == op.op.get("vm.shape_of"): - # vm shape_of is always on the CPU. - self.visit(call.args[0]) - self.unify(self.device_for(call), device_type(cpu(0))) - elif call.op == op.op.get("vm.reshape_tensor"): - data, shape = call.args - self.unify(self.device_for(call), self.device_for(data)) - # The shape field of reshape_tensor is always on the CPU. - self.unify(self.device_for(shape), device_type(cpu(0))) - self.visit(data) - self.visit(shape) - elif isinstance(call.op, Function): - device = self.device_for(call) - for arg in call.args: - device = self.unify(device, self.device_for(arg)) - self.visit(arg) - - for param in call.op.params: - self.visit(param) - device = self.unify(device, self.device_for(param)) - - self.unify(device, self.device_for(call.op)) - self.unify(device, self.device_for(call.op.body)) - # self.visit(call.op) - elif isinstance(call.op, _expr.GlobalVar): - device = self.device_for(call) - assert self.mod, "Cannot analyze context on a globalvar without module" - func = self.mod[call.op] - - assert len(call.args) == len(func.params) - - for arg, param in zip(call.args, func.params): - self.visit(arg) - # Save the the arg to function mapping for closures as it will - # be invoked/unified later. - if isinstance(arg.checked_type, ty.FuncType): - if arg in self.closures: - self.closures[param] = self.closures[arg] - else: - assert isinstance(arg, _expr.GlobalVar) - self.closures[param] = arg - self.unify(self.device_for(arg), self.device_for(param)) - - device = self.unify(device, self.device_for(call.op)) - device = self.unify(device, self.device_for(func)) - device = self.unify(device, self.device_for(func.body)) - # Step into the callee. We need to skip recursive calls, otherwise, it - # would be a infinite loop, so does mutual recursive calls - cur_func = self.current_func - self.current_func = call.op - if cur_func.name_hint != call.op.name_hint: - self.visit(func) - self.visited[func] = device - self.current_func = cur_func - elif isinstance(call.op, _expr.Var): - # It is a closure when we call a var - # Unify the corresponding arguement and parameter - device = self.device_for(call) - assert call.op in self.closures, f"Cannot find {call.op}" - glb_var = self.closures[call.op] - assert self.mod, "Cannot analyze context on a globalvar without module" - func = self.mod[glb_var] - # Unify the underlying function for clousre or currying funcitons. - while is_closure(func) or (isinstance(func.body, _expr.Let) and - func.body.var in self.closures): - device = self.unify(device, self.device_for(func)) - if is_closure(func): - func = func.body - elif (isinstance(func.body, _expr.Let) and func.body.var in self.closures): - func = self.mod[self.closures[func.body.var]] - - assert isinstance(func, Function) - assert len(call.args) == len(func.params) - - for dev_arg, dev_param in zip(call.args, func.params): - self.visit(dev_arg) - self.unify(self.device_for(dev_arg), self.device_for(dev_param)) - - device = self.unify(device, self.device_for(call.op)) - device = self.unify(device, self.device_for(glb_var)) - device = self.unify(device, self.device_for(func)) - cur_func = self.current_func - # Step into the closure. - self.current_func = glb_var - if not tvm.ir.structural_equal(cur_func, glb_var): - self.visit(func) - self.visited[func] = device - self.current_func = cur_func - else: - self.unify_call(call, call.args, [call]) - super().visit_call(call) - - def _extract_closure(self, expr): - while isinstance(expr, _expr.Let): - expr = expr.value - if isinstance(expr, _expr.GlobalVar): - return expr - elif isinstance(expr, _expr.Call) and isinstance(expr.op, - _expr.GlobalVar): - return expr.op - return None - - def visit_let(self, let): - while isinstance(let, _expr.Let): - # Save currying/closures since they will be invoked later - if isinstance(let.value.checked_type, ty.FuncType): - gv = self._extract_closure(let) - assert gv - self.closures[let.var] = gv - - self.unify(self.device_for(let.var), self.device_for(let.value)) - self.unify_expr(let, let.body) - self.visit(let.value) - let = let.body - - self.visit(let) - - def is_primitive(self, func): - return hasattr(func, 'attrs') and hasattr(func.attrs, 'Primitive') \ - and int(func.attrs.Primitive) == 1 - - def visit_function(self, f): - if self.is_primitive(f) or (f in self.visited and self.visited[f].domain): - return - - device = self.unify(self.device_for(f), self.device_for(f.body)) - for x in f.params: - self.device_for(x) - self.visit(f.body) - self.visited[f] = device - - def visit_tuple(self, tup): - # We only support tuple with the same of device. - if not tup: - return - device = self.device_for(tup[0]) - for i in range(1, len(tup)): - device = self.unify(device, self.device_for(tup[i])) - self.unify(device, self.device_for(tup)) - super().visit_tuple(tup) - - def visit_tuple_getitem(self, t): - value = t.tuple_value - if isinstance(t.tuple_value, _expr.Tuple): - self.unify(self.device_for(t), self.device_for(value)) - value = t.tuple_value[t.index] - self.unify(self.device_for(t), self.device_for(value)) - super().visit_tuple_getitem(t) - - def visit_match(self, m): - # For match node, we unify the value and the rhs of each clause - device = self.unify(self.device_for(m), self.device_for(m.data)) - for c in m.clauses: - device = self.unify(device, self.device_for(c.rhs)) - super().visit_match(m) - - def visit_global_var(self, gv): - self.device_for(gv) - - def visit_var(self, var): - self.device_for(var) - - def visit_constant(self, const): - self.device_for(const) - - def results(self): - """Return the analysis result. - - Returns - ------- - ret : Dict[tvm.relay.Expr, DeviceDomain] - The dictionary mapping each expression to a device context. - """ - results = {} - for exp in self.expr_to_device: - device = self.lookup(self.expr_to_device[exp]) - if device.domain is None: - results[exp] = self.fallback_device - else: - results[exp] = device.domain - - return results - - -def mk_analysis_annotator(results): - """Pretty print the annotated relay program with device info""" - def _annotator(exp): - if exp in results: - return f"<{results[exp]}>" - else: - return "" - - return _annotator - - -def context_analysis(mod, fallback_device): - """Perform device context analysis on a given relay program. This requires - that the program has already been annotated and rewritten by replacing on - device annotations with device copy nodes. - - Parameters - ---------- - mod : tvm.IRModule - The IRModule for analysis - - fallback_device : tvm.runtime.TVMContext - The default device context - - Returns - ------- - ret : Dict[tvm.relay.Expr, [int]] - The mapping of each expression to the device context that is - represented in a list form as TVMContext is not a runtime object. - """ - assert isinstance(mod, tvm.IRModule) - # TODO(@zhiics) Apply the pass to all functions/entries - entry = mod.get_global_var("main") - ca = ContextAnalysis(mod, entry, fallback_device) - expr = mod[entry] - ca.visit(expr) - ret = defaultdict(list) - for key, val in ca.results().items(): - ret[key] = [val.device_type, val.device_id] - return ret - - -# register_func("relay.analysis.ContextAnalysis", context_analysis) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index cc610ce25ce2..f8111bbfebae 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -19,7 +19,6 @@ A pass for manifesting explicit memory allocations. """ import numpy as np -import logging from tvm.ir.transform import PassContext, module_pass from tvm import nd, container @@ -33,12 +32,9 @@ from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type from ...import cpu from ..op.memory import alloc_storage -from ..analysis import context_analysis -from ..analysis.context_analysis import mk_analysis_annotator +from ..analysis import context_analysis as _context_analysis from ..._ffi.runtime_ctypes import TVMContext -# logging.basicConfig(level=logging.DEBUG) - def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): offset = expr.const(0, dtype="int64") return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape) @@ -102,14 +98,19 @@ def __init__(self, target_host, context_analysis): self.cached_var = {} super().__init__() - def get_context(self, expr): - assert expr in self.context_analysis, expr.astext(False) - return self.context_analysis[expr] - - def device_copy(self, scope, inp, src_ctx, dst_ctx, idx): + def get_context(self, exp): + """Get the context of a given expression""" + assert exp in self.context_analysis, exp.astext(False) + val = self.context_analysis[exp] + # val[0], val[1] are device_type and device_id, respectively. + # We don't need to unpack after porting this pass to C++. + assert len(val) == 2 + return TVMContext(val[0].value, val[1].value) + + def device_copy(self, inp, src_ctx, dst_ctx): + """Insert a device copy node.""" copy = self.visit(op.tensor.device_copy(inp, src_ctx, dst_ctx)) - copy_out = scope.let("copy_out_{0}".format(idx), copy) - return copy_out + return copy def current_scope(self): return self.scopes[-1] @@ -217,7 +218,7 @@ def emit_shape_func(self, scope, func, new_args): # pass may be needed to decide/remove this copy. if not self.skip_copy_input(subexp) and \ ctx.device_type != cpu_ctx.device_type: - subexp = self.device_copy(scope, subexp, ctx, cpu_ctx, j) + subexp = self.device_copy(subexp, ctx, cpu_ctx) subexp = scope.let("in_arg_{0}".format(input_pos + j), subexp) sh_of = self.visit(self.shape_of(subexp)) @@ -229,7 +230,7 @@ def emit_shape_func(self, scope, func, new_args): elif state == 1: new_arg = self.visit(arg) if ctx.device_type != cpu_ctx.device_type: - new_arg = self.device_copy(scope, new_arg, ctx, cpu_ctx, i) + new_arg = self.device_copy(new_arg, ctx, cpu_ctx) shape_func_ins.append( scope.let("in_shape_{0}".format(input_pos), new_arg)) input_pos += 1 @@ -324,9 +325,9 @@ def visit_call(self, call): attr = call.op.body.attrs else: attr = call.attr - return op.tensor.device_copy(new_args[0], - TVMContext(attr.src_dev_type, 0), - TVMContext(attr.dst_dev_type, 0)) + return self.device_copy(new_args[0], + TVMContext(attr.src_dev_type, 0), + TVMContext(attr.dst_dev_type, 0)) if self.is_dynamic(ret_type): # Handle dynamic case. return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type) @@ -346,6 +347,20 @@ def visit_call(self, call): return super().visit_call(call) +def mk_analysis_annotator(results): + """Pretty print the annotated relay program with device info""" + def _annotator(exp): + if exp in results: + val = results[exp] + assert len(val) == 2 + ctx = TVMContext(val[0].value, val[1].value) + return f"<{ctx}>" + else: + return "" + + return _annotator + + @module_pass(opt_level=0) class ManifestAlloc: """The explicit pass wrapper around ManifestAlloc.""" @@ -358,10 +373,6 @@ def transform_module(self, mod, _): # can we have def pass_init? mod.import_from_std("core.rly") - # We use logger here to help debug. - logging.debug("-----BEFORE CONTEXT ANALYSIS-----") - logging.debug(mod.astext(False)) - assert isinstance(self.targets, (dict, container.Map)) if len(self.targets) > 1: pass_ctx = PassContext.current() @@ -369,25 +380,21 @@ def transform_module(self, mod, _): fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"]) else: fallback_ctx = cpu(0) - ca = context_analysis(mod, TVMContext(fallback_ctx.device_type, 0)) + ca = _context_analysis(mod, TVMContext(fallback_ctx.device_type, 0)) else: if isinstance(self.targets, dict): dev = list(self.targets.keys())[0] else: dev, _ = self.targets.items()[0] - ca = context_analysis(mod, nd.context(dev.value)) + ca = _context_analysis(mod, nd.context(dev.value)) - # TODO(zhiics) This is not needed after we port the pass to C++ - ca_res = {} - for key, val in ca.items(): - ca_res[key] = TVMContext(val[0].value, val[1].value) + # The following code can be used for debugging the module after + # annotation. + # print(mod.astext(show_meta_data=False, annotate=mk_analysis_annotator(ca))) - logging.debug("-----AFTER CONTEXT ANALYSIS-----") - logging.debug(mod.astext(show_meta_data=False, - annotate=mk_analysis_annotator(ca_res))) gv_funcs = mod.functions for gv, f in gv_funcs.items(): - ea = ManifestAllocPass(self.target_host, ca_res) + ea = ManifestAllocPass(self.target_host, ca) f = ea.visit(f) mod.update_func(gv, f) return mod diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 5c664bb47db9..24e8b8075885 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -37,12 +37,13 @@ namespace tvm { namespace relay { -namespace analysis { using PackedAnalysisResultMap = Map>; using AnalysisResultMap = std::unordered_map; +namespace analysis { + // Cache ops static const Op& device_copy_op = Op::Get("device_copy"); static const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); @@ -651,16 +652,18 @@ class ContextAnalyzer : public ExprVisitor { std::unordered_set visited_; }; +} // namespace analysis + AnalysisResultMap ContextAnalysis(const IRModule& mod, const TVMContext& default_context) { // TODO(@zhiics) Apply the pass to all functions/entries auto entry = mod->GetGlobalVar("main"); - auto ca = ContextAnalyzer(mod, entry, default_context); + auto ca = analysis::ContextAnalyzer(mod, entry, default_context); auto expr = mod->Lookup(entry); ca.VisitExpr(expr); return ca.Results(); } -// Unpack the device type and deivce fields in TVMContext for PackedFunc calls +// Unpack the device type and deivce id fields in TVMContext for PackedFunc calls // as TVMContext is not in the object system. PackedAnalysisResultMap ContextAnalysisPacked(const IRModule& mod, const TVMContext& default_context) { @@ -675,11 +678,9 @@ PackedAnalysisResultMap ContextAnalysisPacked(const IRModule& mod, return ret; } -} // namespace analysis - TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis") .set_body_typed([](IRModule mod, TVMContext default_context) { - return analysis::ContextAnalysisPacked(mod, default_context); + return ContextAnalysisPacked(mod, default_context); }); } // namespace relay diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index a71fc8e96aae..1a90accbd6e3 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -63,22 +64,6 @@ Pass ManifestAlloc(Target target_host, vm::TargetsMap targets) { return (*f)(target_host, targets); } -vm::ExprDeviceMap ContextAnalysis(IRModule mod, TVMContext default_device) { - auto f = tvm::runtime::Registry::Get("relay.analysis.ContextAnalysis"); - CHECK(f != nullptr) << "could not load context analysis pass"; - Map> m = (*f)(mod, default_device); - vm::ExprDeviceMap ret; - for (const auto& it : m) { - TVMContext ctx; - Array ints = it.second; - CHECK_EQ(ints.size(), 2U); - ctx.device_type = static_cast(ints[0]->value); - ctx.device_id = static_cast(ints[1]->value); - ret[it.first] = ctx; - } - return ret; -} - Pass MemoryPlan() { auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan"); CHECK(f != nullptr) << "unable to load the memory planning pass"; @@ -1182,13 +1167,13 @@ ExprDeviceMap VMCompiler::AnalyzeContext() const { int fallback_dev = GetFallbackDevice(); default_device.device_type = static_cast(fallback_dev); default_device.device_id = 0; - expr_device_map = transform::ContextAnalysis(context_.module, default_device); + expr_device_map = ContextAnalysis(context_.module, default_device); } else { const auto& tgt = targets_.begin(); default_device.device_type = static_cast((*tgt).first->value); if (default_device.device_type != kDLCPU) { default_device.device_id = 0; - expr_device_map = transform::ContextAnalysis(context_.module, default_device); + expr_device_map = ContextAnalysis(context_.module, default_device); } } return expr_device_map; diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index b2c8016cc7d3..19924ab38358 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,7 +62,7 @@ using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; using TargetsMap = Map; -using ExprDeviceMap = std::unordered_map; +using ExprDeviceMap = std::unordered_map; struct VMCompilerContext { // The module context for the compilation diff --git a/src/runtime/vm/serialize_util.h b/src/runtime/vm/serialize_util.h index 25bdf09e0b5a..d17256d6a079 100644 --- a/src/runtime/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -63,7 +63,8 @@ struct VMFunctionSerializer { VMFunctionSerializer() = default; VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, - const std::vector& params, const std::vector& params_device_type) + const std::vector& params, + const std::vector& params_device_type) : name(name), register_file_size(register_file_size), num_instructions(num_instructions), diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py index 0788a1bdd7c1..535e7a25ba70 100644 --- a/tests/python/relay/test_pass_context_analysis.py +++ b/tests/python/relay/test_pass_context_analysis.py @@ -21,7 +21,7 @@ import tvm from tvm import relay from tvm.relay import expr as _expr, transform -from tvm.relay.analysis.context_analysis import ContextAnalysis, mk_analysis_annotator +from tvm.relay.analysis import context_analysis def test_device_copy(): @@ -34,7 +34,7 @@ def test_device_copy(): out = copy + relay.const(np.random.rand(2, 3)) glb_var = relay.GlobalVar("main") mod[glb_var] = relay.Function([x], out) - ca = ContextAnalysis(mod, glb_var, tvm.cpu()) + ca = context_analysis(mod, tvm.cpu()) ca.visit(mod[glb_var]) ca_res = ca.results() @@ -78,10 +78,8 @@ def test_dynamic_input(): y = nested_data_tuple[1] * data_tuple[1] + data1 mod["main"] = relay.Function([data0, data1], y) compiler = relay.vm.VMCompiler() - entry = mod.get_global_var("main") - ca = ContextAnalysis(mod, entry, tvm.cpu()) - ca.visit(mod[entry]) - ca_res = ca.results() + # mod, _ = compiler.optimize(mod, target="cuda") + # ca = context_analysis(mod, tvm.cpu()) if __name__ == "__main__": pass From 8ba6efdafbef6a904fa33a84d2da94241bd826b1 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 25 Aug 2020 21:43:36 +0000 Subject: [PATCH 12/21] add tests --- python/tvm/relay/transform/memory_alloc.py | 16 +- python/tvm/relay/transform/memory_plan.py | 2 + src/relay/analysis/context_analysis.cc | 5 + .../relay/test_pass_context_analysis.py | 166 +++++++++++++++--- 4 files changed, 151 insertions(+), 38 deletions(-) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index f8111bbfebae..42cd84c0449a 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -109,8 +109,7 @@ def get_context(self, exp): def device_copy(self, inp, src_ctx, dst_ctx): """Insert a device copy node.""" - copy = self.visit(op.tensor.device_copy(inp, src_ctx, dst_ctx)) - return copy + return self.visit(op.tensor.device_copy(inp, src_ctx, dst_ctx)) def current_scope(self): return self.scopes[-1] @@ -208,19 +207,6 @@ def emit_shape_func(self, scope, func, new_args): # Pass Shapes if state == 2: for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)): - # Note vm.shape_of is always executed on CPU. The input may - # need to have a copy as it is likely passed from other - # callers and the argument was on other devices. Constants - # can leave on CPU. - # - # However, this may cause an unnecessary copy when - # all dynamic inputs can be assigned to CPU. An additional - # pass may be needed to decide/remove this copy. - if not self.skip_copy_input(subexp) and \ - ctx.device_type != cpu_ctx.device_type: - subexp = self.device_copy(subexp, ctx, cpu_ctx) - subexp = scope.let("in_arg_{0}".format(input_pos + j), - subexp) sh_of = self.visit(self.shape_of(subexp)) shape_func_ins.append( scope.let("in_shape_{0}".format(input_pos + j), sh_of)) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index a50f49279770..248a79ba44de 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -281,6 +281,8 @@ def process_alloc_storage(self, dynamic_regions, lhs, call): self.enter_scope() dynamic_regions.append(lhs) else: + # A new scope is created when entering a new region with different + # device context. region = self.current_region(dtype) if region.ctx and region.ctx.device_type != ctx.device_type: self.enter_scope() diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 24e8b8075885..d9e229761ebb 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -500,6 +500,11 @@ class ContextAnalyzer : public ExprVisitor { // vm shape_of is always on the CPU. CHECK_EQ(call->args.size(), 1U); ExprVisitor::VisitExpr(call->args[0]); + // Note we don't unify the input of a shape_of with the cpu domain. This is + // because vm.shape_of has a native instruction to compute the shape of + // a tensor regardless its device type. + // Instead, the device type of the input is left for its other consumers to + // unify or it will fallback to the default context. Unify(DeviceFor(GetRef(call)), DeviceType(cpu_ctx_)); } diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py index 535e7a25ba70..aadb640ac251 100644 --- a/tests/python/relay/test_pass_context_analysis.py +++ b/tests/python/relay/test_pass_context_analysis.py @@ -17,6 +17,7 @@ # pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks import numpy as np +import pytest import tvm from tvm import relay @@ -35,51 +36,170 @@ def test_device_copy(): glb_var = relay.GlobalVar("main") mod[glb_var] = relay.Function([x], out) ca = context_analysis(mod, tvm.cpu()) - ca.visit(mod[glb_var]) - ca_res = ca.results() - for expr, dev in ca_res.items(): + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + for expr, dev in ca.items(): if isinstance(expr, _expr.Call): - assert dev.device_type == tvm.gpu().device_type + assert dev[0].value == gpu_dev elif isinstance(expr, _expr.Var): - assert dev.device_type == tvm.cpu().device_type + assert dev[0].value == cpu_dev elif isinstance(expr, _expr.Constant): - assert dev.device_type == tvm.gpu().device_type + assert dev[0].value == gpu_dev def test_shape_func(): - pass + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + data_shape = (relay.Any(),) + x = relay.var("x", shape=data_shape) + y = relay.op.vm.shape_of(x) + z = relay.nn.relu(y) + p0 = relay.var("p0", shape=data_shape) + fn = relay.Function([p0], z) + out = relay.var("out", shape=(1,), dtype="int64") + ins = relay.Tuple([y]) + outs = relay.Tuple([out]) + is_inputs = [False] + shape_func = relay.op.vm.shape_func(fn, ins, outs, is_inputs) + mod["main"] = relay.Function([x, out], shape_func) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + # The output of shape func should be on cpu. + assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev + # shape func is the body and it should be on cpu + assert main.body in ca and ca[main.body][0].value == cpu_dev def test_vm_shape_of(): - pass + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + data_shape = (relay.Any(),) + x = relay.var("x", shape=data_shape) + y = relay.op.vm.shape_of(x) + mod["main"] = relay.Function([x], y) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + assert main.body in ca and ca[main.body][0].value == cpu_dev def test_alloc_storage(): - pass + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + mod.import_from_std("core.rly") + size = relay.Var("size", relay.scalar_type("int64")) + alignment = relay.Var("alignment", relay.scalar_type("int64")) + # allocate a chunk on of memory on gpu. + sto = relay.op.memory.alloc_storage(size, alignment, tvm.gpu()) + mod["main"] = relay.Function([size, alignment], sto) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + body = main.body + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + # Inputs are unified with alloc storage inputs which are on cpu + assert main.params[0] in ca and ca[main.params[0]][0].value == cpu_dev + assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev + + assert isinstance(body, relay.Call) and len(body.args) == 2 + # size of alloc_storage is on cpu + assert body.args[0] in ca and ca[body.args[0]][0].value == cpu_dev + # alignment of alloc_storage is on cpu + assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev + # alloc_storage is on gpu as specified + assert body in ca and ca[body][0].value == gpu_dev def test_alloc_tensor(): - pass + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + mod.import_from_std("core.rly") + sto_type = relay.TypeCall(mod.get_global_type_var("Storage"), []) + sto = relay.Var("x", sto_type) + sh = relay.const(np.array([3, 2]), dtype="int64") + at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh) + mod["main"] = relay.Function([sto], at) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + body = main.body + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + # Input of the function falls back to the default device gpu + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + + assert isinstance(body, relay.Call) and len(body.args) == 3 + # storage of alloc_tensor falls back to the default device gpu + assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev + # shape of alloc_tensor is on cpu + assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev + # alloc_tensor keeps the same device context as storage which is is on gpu + assert body in ca and ca[body][0].value == gpu_dev + + +def test_vm_reshape_tensor(): + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + return + + x = relay.var("x", shape=(2, 8), dtype="float32") + shape = relay.const([-1, 4, 2], dtype="int64") + y = relay.op.vm.reshape_tensor(x, shape, [2, 4, 2]) + mod = tvm.IRModule() + mod["main"] = relay.Function([x], y) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + body = main.body + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + # Input of the function falls back to the default device gpu + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + + # dats of reshape_tensor falls back to the default device gpu + assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev + # shape of reshape_tensor is on cpu + assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev + # reshape_tensor sits on the same device as the data + assert body in ca and ca[body][0].value == gpu_dev def test_dynamic_input(): if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: return + mod = tvm.IRModule() - dtype = "float32" - data_shape = (relay.Any(), 4) - tensor_type = relay.TensorType(data_shape, dtype) - tuple_type = relay.TupleType([tensor_type, tensor_type]) - data0 = relay.var("d0", type_annotation=relay.TupleType([tuple_type, tensor_type])) - data1 = relay.var("d1", shape=(relay.Any(), 4), dtype=dtype) - data_tuple = relay.expr.TupleWrapper(data0, 2) - nested_data_tuple = relay.expr.TupleWrapper(data_tuple[0], 2) - y = nested_data_tuple[1] * data_tuple[1] + data1 - mod["main"] = relay.Function([data0, data1], y) + data_shape = (relay.Any(), relay.Any()) + x0 = relay.var("x0", shape=data_shape) + x1 = relay.var("x1", shape=data_shape) + mod["main"] = relay.Function([x0, x1], x0 + x1) + compiler = relay.vm.VMCompiler() - # mod, _ = compiler.optimize(mod, target="cuda") - # ca = context_analysis(mod, tvm.cpu()) + mod, _ = compiler.optimize(mod, target="cuda") + ca = context_analysis(mod, tvm.cpu()) + main = mod["main"] + + gpu_dev = tvm.gpu().device_type + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + assert main.params[1] in ca and ca[main.params[1]][0].value == gpu_dev + assert main.body in ca and ca[main.body][0].value == gpu_dev + if __name__ == "__main__": - pass + pytest.main([__file__]) From 62fce976493d08257576bdaa0156ebd50c7ca381 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 25 Aug 2020 22:14:18 +0000 Subject: [PATCH 13/21] clean --- include/tvm/runtime/vm/bytecode.h | 2 +- python/tvm/relay/transform/memory_alloc.py | 16 ++-------------- src/relay/analysis/context_analysis.cc | 5 +++++ src/relay/backend/vm/compiler.cc | 4 ++-- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index cb9a59a9ab93..edcbd881e074 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -378,7 +378,7 @@ struct Instruction { * \param src_device_type The device type of the tensor for the source register. * \param dst_device_type The device type of the tensor ofr the destination register. * \param dst The destination register to store the copied tensor. - * \return The reshape tensor instruction. + * \return The device copy instruction. */ static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, RegName dst); diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 42cd84c0449a..3c7d86459689 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -95,7 +95,6 @@ def __init__(self, target_host, context_analysis): self.default_context = cpu(0) self.compute_dtype = "int64" self.context_analysis = context_analysis - self.cached_var = {} super().__init__() def get_context(self, exp): @@ -173,7 +172,6 @@ def visit_let(self, let): self.scopes.append(scope) while isinstance(let, expr.Let): - self.cached_var[let.var] = let.value new_val = self.visit(let.value) scope.let(let.var, new_val) let = let.body @@ -184,13 +182,6 @@ def visit_let(self, let): return scope.get() - def skip_copy_input(self, var): - """Check if device copy for the input should be skipped. We currently - skip copying it when the input is a constant. - """ - return (var in self.cached_var and isinstance(self.cached_var[var], - expr.Constant)) - def emit_shape_func(self, scope, func, new_args): """Insert the shape function given a primitive function.""" shape_func_ins = [] @@ -203,7 +194,6 @@ def emit_shape_func(self, scope, func, new_args): cpu_ctx = nd.cpu(0) for i, (arg, state) in enumerate(zip(new_args, input_states)): state = int(state) - ctx = self.get_context(arg) # Pass Shapes if state == 2: for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)): @@ -215,6 +205,7 @@ def emit_shape_func(self, scope, func, new_args): # Pass Inputs elif state == 1: new_arg = self.visit(arg) + ctx = self.get_context(arg) if ctx.device_type != cpu_ctx.device_type: new_arg = self.device_copy(new_arg, ctx, cpu_ctx) shape_func_ins.append( @@ -275,14 +266,11 @@ def emit_reshape_tensor(self, scope, func, new_args, ret_type): if self.is_dynamic(ret_type): out_shapes = self.emit_shape_func(scope, func, new_args) shape_expr = out_shapes[0] - inp = new_args[0] - ret = self.reshape_tensor(inp, shape_expr, ret_type.shape) - return ret else: # constant output shape shape = [int(dim) for dim in ret_type.shape] shape_expr = expr.const(shape, dtype=self.compute_dtype) - return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape) + return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape) def is_dynamic(self, ret_type): is_dynamic = ty.is_dynamic(ret_type) diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index d9e229761ebb..3f749d6e8dce 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -443,6 +443,7 @@ class ContextAnalyzer : public ExprVisitor { } void UnifyAllocStorageCall(const CallNode* call) { + // [size, alignment] CHECK_EQ(call->args.size(), 2U); // The arguments of alloc storage should be on CPU. @@ -458,6 +459,7 @@ class ContextAnalyzer : public ExprVisitor { } void UnifyAllocTensorCall(const CallNode* call) { + // [storage, offset, shape] CHECK_EQ(call->args.size(), 3U); Expr storage = call->args[0]; @@ -470,6 +472,7 @@ class ContextAnalyzer : public ExprVisitor { } void UnifyShapeFuncCall(const CallNode* call) { + // [func, inputs, outputs] CHECK_EQ(call->args.size(), 3U); auto shape_func_domain = DeviceType(cpu_ctx_); @@ -489,6 +492,7 @@ class ContextAnalyzer : public ExprVisitor { } void UnifyInvokeTVMOpCall(const CallNode* call) { + // [op, inputs, outputs] CHECK_EQ(call->args.size(), 3U); Tuple inps = Downcast(call->args[1]); Tuple outputs = Downcast(call->args[2]); @@ -509,6 +513,7 @@ class ContextAnalyzer : public ExprVisitor { } void UnifyReshapeTensorCall(const CallNode* call) { + // [data, shape] CHECK_EQ(call->args.size(), 2U); Expr data = call->args[0]; Expr shape = call->args[1]; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 1a90accbd6e3..b436514b59b7 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -640,9 +640,9 @@ class VMFunctionCompiler : ExprFunctor { auto dtype = alloc_attrs->dtype; Index device_type; - // There is bug if all expression are annotated with the device that - // other than the first one in the target list. if (expr_device_map_.empty()) { + // TODO(zhiics) There is bug if all expressions are annotated with the device + // that is different the first one in the target list. auto& kv = *(targets_.begin()); device_type = kv.first; } else { From e7c235a6ae201a9f9979d0682592c752202dfe5b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 25 Aug 2020 23:35:02 +0000 Subject: [PATCH 14/21] lint --- python/tvm/ir/module.py | 6 +++--- python/tvm/relay/analysis/__init__.py | 2 -- python/tvm/relay/transform/memory_alloc.py | 1 + 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 851e40bd0ad2..2f6fd2069460 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -118,19 +118,19 @@ def update(self, other): other = Module(other) return _ffi_api.Module_Update(self, other) - def update_func(self, gv, func): + def update_func(self, var, func): """Update the function corresponding to a global variable in the module. Parameters ---------- - gv: GlobalVar + var: GlobalVar The global variable. func: tvm.relay.Function The function to be inserted. """ - return _ffi_api.Module_UpdateFunction(self, gv, func) + return _ffi_api.Module_UpdateFunction(self, var, func) def get_global_var(self, name): """Get a global variable in the function by name. diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py index 0c065ef07a6e..e5b21cb107f5 100644 --- a/python/tvm/relay/analysis/__init__.py +++ b/python/tvm/relay/analysis/__init__.py @@ -29,5 +29,3 @@ # Feature from . import feature from . import sparse_dense - -from . import context_analysis diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 3c7d86459689..9e104021e5fe 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -343,6 +343,7 @@ def __init__(self, target_host, targets): self.targets = targets def transform_module(self, mod, _): + """Invokes the pass""" # TODO(@jroesch): Is there a way to do one shot initialization? # can we have def pass_init? mod.import_from_std("core.rly") From b64a5f1232f2fc8b74568efd1f8a21225f8bdca5 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 26 Aug 2020 00:20:59 +0000 Subject: [PATCH 15/21] fix --- src/relay/analysis/context_analysis.cc | 2 +- src/runtime/vm/executable.cc | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 3f749d6e8dce..3740aa2f3afb 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -156,7 +156,7 @@ class ContextAnalyzer : public ExprVisitor { while (device_uf_.count(device) && device != device_uf_[device]) { // Path compression if (device_uf_.count(device_uf_[device])) { - device_uf_[device] == device_uf_[device_uf_[device]]; + device_uf_[device] = device_uf_[device_uf_[device]]; } device = device_uf_[device]; } diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index c95b739a1d44..cad145d9c707 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -23,7 +23,6 @@ */ #include -#include #include #include #include From 23d25fb8e268d4ecdce83cdb4901a6950ad52912 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 26 Aug 2020 16:16:56 +0000 Subject: [PATCH 16/21] enable gpu test for dynamic namespace --- python/tvm/relay/analysis/analysis.py | 2 +- python/tvm/relay/transform/memory_alloc.py | 6 +++--- tests/python/relay/dyn/test_dynamic_op_level10.py | 11 +++++------ tests/python/relay/dyn/test_dynamic_op_level2.py | 1 - tests/python/relay/dyn/test_dynamic_op_level3.py | 2 -- tests/python/relay/dyn/test_dynamic_op_level5.py | 1 - 6 files changed, 9 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index a6662e6c5eb3..d417c2b39b08 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -34,7 +34,7 @@ def context_analysis(mod, default_context): Parameters ---------- - expr : tvm.IRModule + mod : tvm.IRModule The input module. default_context : tvm.runtime.TVMContext diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 9e104021e5fe..bc39b3c35051 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -47,14 +47,14 @@ def is_primitive(call): def is_device_copy(func): """ - Check if the current relay expression is shape_of call. We can simply check - the body of it if it is a function becase the shape_of op is opaque. + Check if the current relay expression is a device copy call. We can simply check + the body of it if it is a function becase the device_copy op is opaque. """ if isinstance(func, Function): body = func.body return isinstance(body, expr.Call) and body.op == op.get("device_copy") if isinstance(func, expr.Call): - return body.op == op.get("device_copy") + return func.op == op.get("device_copy") return False diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index 0097a4eed9dc..8bc551be0ff1 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -47,12 +47,11 @@ def test_dyn_broadcast_to(): dyn_shape = (1, ) * rank ref_res = np.broadcast_to(x, dyn_shape) for target, ctx in tvm.testing.enabled_targets(): - if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) @tvm.testing.uses_gpu diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index bab4869b3fe0..15b6b7acd7e9 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -52,7 +52,6 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa func = relay.Function([x, scale_h_var, scale_w_var], z) for target, ctx in tvm.testing.enabled_targets(): - if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 193de85a5242..d6a2806719ab 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -28,8 +28,6 @@ def verify_func(func, data, ref_res): assert isinstance(data, list) for target, ctx in tvm.testing.enabled_targets(): - #TODO(mbrookhart): enable Cuda tests onces the VM supports dynamic shapes - if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index 226bbfe2678e..eb804fe430e3 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -60,7 +60,6 @@ def verify_resize(dshape, scale, method, layout): func = relay.Function([x, size_var], z) for target, ctx in tvm.testing.enabled_targets(): - if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) From 1bac1cf52dec3c0e33f732eb6295101040365997 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 26 Aug 2020 20:48:20 +0000 Subject: [PATCH 17/21] remove GetParamsContext --- include/tvm/runtime/vm/vm.h | 3 --- python/tvm/relay/transform/memory_alloc.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index ba2585b49c3d..38ccd2b324c0 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -248,9 +248,6 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief Run VM dispatch loop. */ void RunLoop(); - /*! \brief Get device context for params. */ - TVMContext GetParamsContext() const; - /*! \brief Get context from the context list based on a given device type. */ TVMContext GetContext(Index device_type) const; diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index bc39b3c35051..7e191f0261c5 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -302,6 +302,7 @@ def visit_call(self, call): return self.device_copy(new_args[0], TVMContext(attr.src_dev_type, 0), TVMContext(attr.dst_dev_type, 0)) + if self.is_dynamic(ret_type): # Handle dynamic case. return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type) From b416675d434842ed4d2ec380323cbf1c7d116ec4 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 28 Aug 2020 04:05:56 +0000 Subject: [PATCH 18/21] fix comments and add doc for context analysis --- python/tvm/relay/transform/memory_alloc.py | 15 ++++++------ src/relay/analysis/context_analysis.cc | 27 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 7e191f0261c5..e6f17f996bbf 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -32,7 +32,7 @@ from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type from ...import cpu from ..op.memory import alloc_storage -from ..analysis import context_analysis as _context_analysis +from ..analysis import context_analysis from ..._ffi.runtime_ctypes import TVMContext def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): @@ -85,7 +85,7 @@ def is_reshape_only(func): class ManifestAllocPass(ExprMutator): """A pass for explicitly manifesting all memory allocations in Relay.""" - def __init__(self, target_host, context_analysis): + def __init__(self, target_host, context_analysis_map): self.invoke_tvm = op.vm.invoke_tvm_op self.shape_func = op.vm.shape_func self.shape_of = op.vm.shape_of @@ -94,13 +94,13 @@ def __init__(self, target_host, context_analysis): self.target_host = target_host self.default_context = cpu(0) self.compute_dtype = "int64" - self.context_analysis = context_analysis + self.context_analysis_map = context_analysis_map super().__init__() def get_context(self, exp): """Get the context of a given expression""" - assert exp in self.context_analysis, exp.astext(False) - val = self.context_analysis[exp] + assert exp in self.context_analysis_map, exp.astext(False) + val = self.context_analysis_map[exp] # val[0], val[1] are device_type and device_id, respectively. # We don't need to unpack after porting this pass to C++. assert len(val) == 2 @@ -339,6 +339,7 @@ def _annotator(exp): @module_pass(opt_level=0) class ManifestAlloc: """The explicit pass wrapper around ManifestAlloc.""" + # TODO(zhiics, jroesch) Port this pass to C++. def __init__(self, target_host, targets): self.target_host = target_host self.targets = targets @@ -356,13 +357,13 @@ def transform_module(self, mod, _): fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"]) else: fallback_ctx = cpu(0) - ca = _context_analysis(mod, TVMContext(fallback_ctx.device_type, 0)) + ca = context_analysis(mod, TVMContext(fallback_ctx.device_type, 0)) else: if isinstance(self.targets, dict): dev = list(self.targets.keys())[0] else: dev, _ = self.targets.items()[0] - ca = _context_analysis(mod, nd.context(dev.value)) + ca = context_analysis(mod, nd.context(dev.value)) # The following code can be used for debugging the module after # annotation. diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 3740aa2f3afb..0e1296768d82 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -20,6 +20,33 @@ /*! * \file src/relay/analysis/context_analysis.cc * \brief A pass for analyzing device attribute of each IR node. + * + * We use union-find data structures to analyze the context information of each + * sub-expression in a Relay program in this pass. Only the device copy node in + * Relay directly contains bidiretional device information. We use it to + * bidirectionally propagate the device info of its inputs and outputs. + * + * However, to support dynamism (e.g dynamic inputs), Relay introduces several + * concepts to compute the shape of tensors and operators at runtime, i.e. + * shape_of, shape_func, and reshape_tensor. These nodes are also referred to as + * VM dialects as we have native VM instructions for them. These dialects are + * intrinsically CPU friendly, therefore, they are only designed to be + * executed on CPU. We, hence, unify their inputs and outputs to CPU as well. + * Note the input of shape_of is a tensor and we only need the tensor shape. + * Therefore, the input could be sitting on GPU as well since no real data is + * needed. The context of the input would be propagated from its other + * consumers or fallback to the default device. + * + * Another type of dialect is used fo memory allocation, namely, alloc_storage + * and alloc_tensor. alloc_storage contains a context field to indicate where + * the chunk of memory is allocated. Therefore, we unify the context of + * alloc_storage with the context field. Other inputs, such as size and + * alignment, are left on CPU. + * + * Based on the above rules, we keep unifying the connected expressions and + * propagating their device information. An error will be raised whenever there + * is a unification conflict. All IR nodes that are not propagated with device + * context will fallback to the specified device. */ #include From 07e70914d2c885a17ca6841d016c7d8eb13d8365 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 31 Aug 2020 00:17:21 +0000 Subject: [PATCH 19/21] cache context --- python/tvm/runtime/vm.py | 10 +++++---- src/runtime/vm/executable.cc | 3 ++- src/runtime/vm/vm.cc | 40 +++++++++++++++++------------------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 4bfc5fd0a8a5..69d94b3f2c84 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -309,11 +309,13 @@ def _setup_ctx(self, ctx, memory_cfg): """Init context and allocators.""" ctxs = ctx if not isinstance(ctx, (list, tuple)): - assert isinstance(ctx, tvm.runtime.TVMContext) + if not isinstance(ctx, tvm.runtime.TVMContext): + raise TypeError("ctx is expected to be TVMContex") ctxs = [ctx] - # CPU is required for executing shape functions - if ctx.device_type != tvm.cpu(0).device_type: - ctxs.append(tvm.cpu()) + + # CPU is required for executing shape functions + if not any(c.device_type == tvm.cpu().device_type for c in ctxs): + ctxs.append(tvm.cpu()) default_alloc_type = VirtualMachine.POOLED_ALLOCATOR if memory_cfg is None: diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index cad145d9c707..cc1dc8dd19e5 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -644,7 +644,8 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); } case Opcode::AllocStorage: { - DCHECK_GE(instr.fields.size(), 6U); + // Number of fields = 7 + DCHECK_GE(instr.fields.size(), 7U); Index allocation_size = instr.fields[0]; Index alignment = instr.fields[1]; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 06a706e6ac7b..3dcdd065b668 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -77,7 +77,7 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { for (size_t i = 0; i < adt.size(); i++) { ret.push_back(CopyTo(adt[i], ctx)); } - return ADT(0, ret.begin(), ret.end()); + return ADT(adt->tag, ret.begin(), ret.end()); } } @@ -161,11 +161,8 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, << "The number of provided parameters doesn't match the number of assigned devices"; std::vector func_args(param_names.size()); for (int i = 1; i < args.size(); ++i) { - TVMContext ctx; - int device_type = vm_func.params_device_type[i - 1]; - ctx.device_type = DLDeviceType(device_type); - // TODO(zhiics) Use virtual device id - ctx.device_id = 0; + Index device_type = vm_func.params_device_type[i - 1]; + DLContext ctx = GetContext(device_type); ObjectRef obj = CopyTo(args[i], ctx); func_args[i - 1] = obj; } @@ -178,15 +175,13 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } } -TVMContext VirtualMachine::GetContext(Index device_type) const { - CHECK(!ctxs_.empty()) << "Context has not been initialized yet."; - - const auto& cit = std::find_if(ctxs_.begin(), ctxs_.end(), [&device_type](const TVMContext& c) { - return device_type == static_cast(c.device_type); - }); +inline TVMContext VirtualMachine::GetContext(Index device_type) const { + CHECK_GE(ctxs_.size(), device_type) << "ctxs_ list doesn't contain device:" << device_type; - CHECK(cit != ctxs_.end()) << "device type " << device_type << " not found int the context list."; - return *cit; + auto ctx = ctxs_[device_type]; + CHECK_EQ(static_cast(ctx.device_type), device_type) + << "device type " << device_type << " has not been initialized int the context list."; + return ctx; } void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { @@ -294,7 +289,14 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { void VirtualMachine::Init(const std::vector& ctxs, const std::vector& alloc_types) { CHECK_EQ(ctxs.size(), alloc_types.size()); - ctxs_ = ctxs; + // Cache the context + for (const auto& it : ctxs) { + auto dev_type = static_cast(it.device_type); + if (ctxs_.size() <= dev_type) { + ctxs_.resize(dev_type + 1); + } + ctxs_[dev_type] = it; + } for (size_t i = 0; i < ctxs.size(); ++i) { auto alloc = MemoryManager::GetOrCreateAllocator(ctxs[i], alloc_types[i]); allocators_.emplace(ctxs[i], alloc); @@ -484,9 +486,7 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::AllocTensorReg: { - DLContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; + DLContext cpu_ctx = GetContext(static_cast(kDLCPU)); auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = Downcast(CopyTo(shape_obj, cpu_ctx)); auto shape = ToShape(shape_tensor); @@ -566,9 +566,7 @@ void VirtualMachine::RunLoop() { } } case Opcode::ReshapeTensor: { - DLContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; + DLContext cpu_ctx = GetContext(static_cast(kDLCPU)); auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor); NDArray tensor_arr = Downcast(tensor_obj); // Read the shape from shape tensor From df47ba4c01cd256aca89c4d8c481bcef48f6b59e Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 2 Sep 2020 22:09:21 +0000 Subject: [PATCH 20/21] cache allocator --- include/tvm/runtime/vm/vm.h | 4 ++-- src/runtime/vm/vm.cc | 23 +++++++++++------------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 38ccd2b324c0..e9f51de611b6 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -277,8 +277,8 @@ class VirtualMachine : public runtime::ModuleNode { std::unordered_map> inputs_; /*! \brief The set of TVM contexts the VM is currently executing on. */ std::vector ctxs_; - /*! \brief The mapping from TVM context to memory allocator. */ - std::unordered_map allocators_; + /*! \brief The cached memory allocators. */ + std::vector allocators_; /*! * \brief The constant pool for runtime. It caches the device dependent * object to avoid rellocation of constants during inference. diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 3dcdd065b668..aeee137530b1 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -290,16 +290,15 @@ void VirtualMachine::Init(const std::vector& ctxs, const std::vector& alloc_types) { CHECK_EQ(ctxs.size(), alloc_types.size()); // Cache the context - for (const auto& it : ctxs) { - auto dev_type = static_cast(it.device_type); + for (size_t i = 0; i < ctxs.size(); i++) { + auto dev_type = static_cast(ctxs[i].device_type); + auto alloc = MemoryManager::GetOrCreateAllocator(ctxs[i], alloc_types[i]); if (ctxs_.size() <= dev_type) { ctxs_.resize(dev_type + 1); + allocators_.resize(dev_type + 1); } - ctxs_[dev_type] = it; - } - for (size_t i = 0; i < ctxs.size(); ++i) { - auto alloc = MemoryManager::GetOrCreateAllocator(ctxs[i], alloc_types[i]); - allocators_.emplace(ctxs[i], alloc); + ctxs_[dev_type] = ctxs[i]; + allocators_[dev_type] = alloc; } } @@ -527,11 +526,11 @@ void VirtualMachine::RunLoop() { << ", device_type=" << instr.alloc_storage.device_type; auto storage_obj = SimpleObjAllocator().make_object(); - auto ctx = GetContext(instr.alloc_storage.device_type); - auto it = allocators_.find(ctx); - CHECK(it != allocators_.end()) - << "Did you forget to init the VirtualMachine with contexts?"; - auto alloc = it->second; + auto dev_type = instr.alloc_storage.device_type; + CHECK_LT(static_cast(dev_type), allocators_.size()) + << "Memory allocator for device " << dev_type << " has not been initialized"; + auto* alloc = allocators_[dev_type]; + CHECK(alloc) << "Did you forget to init the VirtualMachine with contexts?"; storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint); Storage storage(storage_obj); WriteRegister(instr.dst, storage); From 8f572b3a38c0b67788bb41e6ec3cf35d471dc172 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 3 Sep 2020 03:06:38 +0000 Subject: [PATCH 21/21] rebase and fix comments --- python/tvm/runtime/vm.py | 3 ++- src/relay/analysis/context_analysis.cc | 6 +----- src/relay/backend/vm/compiler.cc | 14 ++++++-------- tests/python/relay/benchmarking/benchmark_vm.py | 2 +- tests/python/relay/test_adt.py | 2 +- tests/python/relay/test_any.py | 3 +-- tests/python/relay/test_pass_context_analysis.py | 16 ++++++++-------- tests/python/relay/test_vm_serialization.py | 2 +- .../python/unittest/test_runtime_vm_profiler.py | 4 ++-- 9 files changed, 23 insertions(+), 29 deletions(-) diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 69d94b3f2c84..fbc7a7d7b71e 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -310,7 +310,8 @@ def _setup_ctx(self, ctx, memory_cfg): ctxs = ctx if not isinstance(ctx, (list, tuple)): if not isinstance(ctx, tvm.runtime.TVMContext): - raise TypeError("ctx is expected to be TVMContex") + raise TypeError("ctx is expected to be TVMContext or \ + List[TVMContext]") ctxs = [ctx] # CPU is required for executing shape functions diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 0e1296768d82..bbea0399c117 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -451,7 +451,6 @@ class ContextAnalyzer : public ExprVisitor { inps.push_back(fn->params[0]); outs.push_back(call->op); Expr body = fn->body; - // outs.push_back(fn->body); CHECK(body->IsInstance() && IsDeviceCopy(body)); Call call_body = Downcast(body); attrs = call_body->attrs.as(); @@ -715,10 +714,7 @@ PackedAnalysisResultMap ContextAnalysisPacked(const IRModule& mod, return ret; } -TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis") - .set_body_typed([](IRModule mod, TVMContext default_context) { - return ContextAnalysisPacked(mod, default_context); - }); +TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis").set_body_typed(ContextAnalysisPacked); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b436514b59b7..18b23c42c6ea 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -531,14 +531,12 @@ class VMFunctionCompiler : ExprFunctor { const auto& it = targets_.begin(); target = (*it).second; } else { - if (expr_device_map_.count(func) == 0 || - targets_.count(expr_device_map_[func].device_type) == 0) { - int fallback_dev = GetFallbackDevice(); - auto dev_name = runtime::DeviceName(fallback_dev); - if (expr_device_map_.count(func) == 0) { - LOG(WARNING) << "The function is not annotated. Fallback to " << dev_name; - } - target = CreateDefaultTarget(fallback_dev); + CHECK_GT(expr_device_map_.count(func), 0U) + << "Found not annotated expression, please make sure " + "context analysis has been executed"; + int dev_type = expr_device_map_[func].device_type; + if (targets_.count(dev_type) == 0) { + target = CreateDefaultTarget(dev_type); } else { target = targets_[expr_device_map_[func].device_type]; } diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py index 073ad6a4ca05..4fcf39d0aae2 100644 --- a/tests/python/relay/benchmarking/benchmark_vm.py +++ b/tests/python/relay/benchmarking/benchmark_vm.py @@ -79,7 +79,7 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32', # random input data = np.random.uniform(size=data_shape).astype(dtype) - for target, ctx in testing.ctx_list(): + for target, ctx in testing.enabled_targets(): tvm_out = get_graph_runtime_output(mod, tvm.nd.array(data.astype(dtype)), params, target, ctx, dtype) vm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params, diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 48f2292140b1..d0e010570c8a 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -721,7 +721,7 @@ def test_iterate(): def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5): for kind in ["debug", "vm"]: - for target, ctx in testing.ctx_list(): + for target, ctx in testing.enabled_targets(): if kind == "debug" and ctx.device_type != tvm.cpu().device_type: continue ex = relay.create_executor(kind, mod=ta_mod, ctx=ctx, target=target) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 38ceb2e92e55..e33e2679fe5d 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -23,7 +23,6 @@ from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type import tvm.topi.testing -from tvm.relay.testing.config import ctx_list def int32(val): return relay.const(val, 'int32') @@ -37,7 +36,7 @@ def any_dims(ndim): def check_result(args, mod, expected, flatten=False, assert_shape=False, only_vm=False): for kind in ["debug", "vm"]: - for tgt, ctx in ctx_list(): + for tgt, ctx in tvm.testing.enabled_targets(): if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type): continue diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py index aadb640ac251..e54682be7871 100644 --- a/tests/python/relay/test_pass_context_analysis.py +++ b/tests/python/relay/test_pass_context_analysis.py @@ -21,12 +21,12 @@ import tvm from tvm import relay -from tvm.relay import expr as _expr, transform +from tvm.relay import expr as _expr from tvm.relay.analysis import context_analysis def test_device_copy(): - if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: return mod = tvm.IRModule() @@ -49,7 +49,7 @@ def test_device_copy(): def test_shape_func(): - if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: return mod = tvm.IRModule() @@ -78,7 +78,7 @@ def test_shape_func(): def test_vm_shape_of(): - if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: return mod = tvm.IRModule() @@ -96,7 +96,7 @@ def test_vm_shape_of(): def test_alloc_storage(): - if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: return mod = tvm.IRModule() @@ -126,7 +126,7 @@ def test_alloc_storage(): def test_alloc_tensor(): - if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: return mod = tvm.IRModule() @@ -155,7 +155,7 @@ def test_alloc_tensor(): def test_vm_reshape_tensor(): - if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: return x = relay.var("x", shape=(2, 8), dtype="float32") @@ -181,7 +181,7 @@ def test_vm_reshape_tensor(): def test_dynamic_input(): - if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: return mod = tvm.IRModule() diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 6486b707fcec..b304c435fdfa 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -307,7 +307,7 @@ def test_dynamic_bcast(): x_data = np.random.uniform(size=(1, 2)).astype(dtype) y_data = np.random.uniform(size=(3, 2)).astype(dtype) res_np = np.add(x_data, y_data) - for target, ctx in testing.ctx_list(): + for target, ctx in testing.enabled_targets(): res = get_serialized_output(mod, *(x_data, y_data), target=target, ctx=ctx) tvm.testing.assert_allclose(res.asnumpy(), res_np) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index cbcb022589e6..9e484357cf3e 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -18,14 +18,14 @@ from tvm.runtime import profiler_vm from tvm import relay -from tvm.relay.testing import resnet, ctx_list +from tvm.relay.testing import resnet, enabled_targets def test_basic(): mod, params = resnet.get_workload() if not profiler_vm.enabled(): return - for target, ctx in ctx_list(): + for target, ctx in enabled_targets(): exe = relay.vm.compile(mod, target, params=params) vm = profiler_vm.VirtualMachineProfiler(exe, ctx)