diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index e4b39da85206..123b7e395faa 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -31,6 +31,7 @@ #include #include #include +#include #include @@ -419,6 +420,17 @@ TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); */ TVM_DLL Pass SimplifyExpr(); +/*! + * \brief A pass for manifesting explicit memory allocations and rewriting + * specific dialects. + * + * \param target_host The target used by the host for compilation. + * \param targets The device type and target pairs for compilation. + * + * \return The pass. + */ +TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); + } // namespace transform /*! diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 97f6d1cb60c0..89c8fcb17d73 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -61,7 +61,6 @@ from .scope_builder import ScopeBuilder # Load Memory Passes -from .transform import memory_alloc from .transform import memory_plan # Required to traverse large programs diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 1d0ea176b16f..ca9996aeaaae 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,4 +19,3 @@ # transformation passes from .transform import * from .recast import recast -from . import memory_alloc diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py deleted file mode 100644 index 66528c861788..000000000000 --- a/python/tvm/relay/transform/memory_alloc.py +++ /dev/null @@ -1,389 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks -""" -A pass for manifesting explicit memory allocations. -""" -import numpy as np - -from tvm.ir.transform import PassContext, module_pass -from tvm.relay.transform import InferType -from tvm import nd, container -from ..function import Function -from ..expr_functor import ExprVisitor, ExprMutator -from ..scope_builder import ScopeBuilder -from .. import op -from ... import DataType, register_func -from .. import ty, expr -from ..backend import compile_engine -from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type -from ... 
import cpu -from ..op.memory import alloc_storage -from ..analysis import context_analysis -from ..._ffi.runtime_ctypes import TVMContext - - -def alloc_tensor(storage, shape, dtype="float32", assert_shape=None): - offset = expr.const(0, dtype="int64") - return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape) - - -def is_primitive(call): - return ( - hasattr(call, "op") - and hasattr(call.op, "attrs") - and hasattr(call.op.attrs, "Primitive") - and int(call.op.attrs.Primitive) == 1 - ) - - -def is_device_copy(func): - """ - Check if the current relay expression is a device copy call. We can simply check - the body of it if it is a function becase the device_copy op is opaque. - """ - if isinstance(func, Function): - body = func.body - return isinstance(body, expr.Call) and body.op == op.get("device_copy") - if isinstance(func, expr.Call): - return func.op == op.get("device_copy") - return False - - -class CheckReshapeOnly(ExprVisitor): - """A pass to check if the fused op contains only reshape ops.""" - - def __init__(self): - super().__init__() - self._reshape_ops = [ - op.get("reshape"), - op.get("contrib_reverse_reshape"), - op.get("dyn.reshape"), - ] - self.reshape_only = True - - def visit_call(self, call): - if not self.reshape_only: - return - if call.op not in self._reshape_ops: - self.reshape_only = False - for arg in call.args: - self.visit(arg) - - def visit_var(self, var): - var_type = var.checked_type - if not isinstance(var_type, ty.TensorType): - self.reshape_only = False - - -def is_reshape_only(func): - """Check if the primitive function contains only reshape ops.""" - check = CheckReshapeOnly() - check.visit(func) - return check.reshape_only - - -class ManifestAllocPass(ExprMutator): - """A pass for explicitly manifesting all memory allocations in Relay.""" - - def __init__(self, target_host, context_analysis_map): - self.invoke_tvm = op.vm.invoke_tvm_op - self.shape_func = op.vm.shape_func - self.shape_of = op.vm.shape_of - self.reshape_tensor = op.vm.reshape_tensor - self.scopes = [ScopeBuilder()] - self.target_host = target_host - self.default_context = cpu(0) - self.compute_dtype = "int64" - self.context_analysis_map = context_analysis_map - super().__init__() - - def get_context(self, exp): - """Get the context of a given expression""" - assert exp in self.context_analysis_map, exp.astext(False) - val = self.context_analysis_map[exp] - # val[0], val[1] are device_type and device_id, respectively. - # We don't need to unpack after porting this pass to C++. 
- assert len(val) == 2 - return TVMContext(val[0].value, val[1].value) - - def device_copy(self, inp, src_ctx, dst_ctx): - """Insert a device copy node.""" - return self.visit(op.tensor.device_copy(inp, src_ctx, dst_ctx)) - - def current_scope(self): - return self.scopes[-1] - - def visit_tuple(self, tup): - scope = self.current_scope() - new_fields = [] - for field in tup.fields: - field = self.visit(field) - if isinstance(field, expr.Constant): - field = scope.let("const", field) - new_fields.append(field) - return expr.Tuple(new_fields) - - def compute_alignment(self, dtype): - dtype = DataType(dtype) - align = (dtype.bits // 8) * dtype.lanes - # MAGIC CONSTANT FROM device_api.h - if align < 64: - align = 64 - - return expr.const(align, dtype="int64") - - def compute_storage_in_relay(self, shape, dtype): - dtype = DataType(dtype) - els = op.prod(shape) - num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype) - num = num + expr.const(7, self.compute_dtype) - div = expr.const(8, self.compute_dtype) - return els * (num / div) - - def compute_storage(self, tensor_type): - dtype = DataType(tensor_type.dtype) - shape = [int(sh) for sh in tensor_type.shape] - size = 1 - for sh in shape: - size *= sh - size *= (dtype.bits * dtype.lanes + 7) // 8 - return expr.const(size, dtype=self.compute_dtype) - - def make_static_allocation(self, scope, tensor_type, ctx, name_hint): - """Allocate a tensor with a statically known shape.""" - shape = [int(sh) for sh in tensor_type.shape] - if len(shape) == 0: - shape = expr.const(np.empty((), dtype=self.compute_dtype), dtype=self.compute_dtype) - else: - shape = expr.const(np.array(shape), dtype=self.compute_dtype) - size = self.compute_storage(tensor_type) - alignment = self.compute_alignment(tensor_type.dtype) - dtype = tensor_type.dtype - sto = scope.let("storage_{0}".format(name_hint), alloc_storage(size, alignment, ctx, dtype)) - # TODO(@jroesch): There is a bug with typing based on the constant shape. 
- tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape) - return scope.let("tensor_{0}".format(name_hint), tensor) - - def visit_let(self, let): - scope = ScopeBuilder() - - self.scopes.append(scope) - while isinstance(let, expr.Let): - new_val = self.visit(let.value) - scope.let(let.var, new_val) - let = let.body - - new_body = self.visit(let) - scope.ret(new_body) - self.scopes.pop() - - return scope.get() - - def emit_shape_func(self, scope, func, new_args): - """Insert the shape function given a primitive function.""" - shape_func_ins = [] - engine = compile_engine.get() - cfunc = engine.lower_shape_func(func, self.target_host) - input_states = cfunc.shape_func_param_states - - is_inputs = [] - input_pos = 0 - cpu_ctx = nd.cpu(0) - for i, (arg, state) in enumerate(zip(new_args, input_states)): - state = int(state) - # Pass Shapes - if state == 2: - for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)): - sh_of = self.visit(self.shape_of(subexp)) - shape_func_ins.append(scope.let("in_shape_{0}".format(input_pos + j), sh_of)) - input_pos += 1 - is_inputs.append(0) - # Pass Inputs - elif state == 1: - new_arg = self.visit(arg) - ctx = self.get_context(arg) - if ctx.device_type != cpu_ctx.device_type: - new_arg = self.device_copy(new_arg, ctx, cpu_ctx) - shape_func_ins.append(scope.let("in_shape_{0}".format(input_pos), new_arg)) - input_pos += 1 - is_inputs.append(1) - else: - # TODO(@jroesch): handle 3rd case - raise Exception("unsupported shape function input state") - - out_shapes = [] - for i, out in enumerate(cfunc.outputs): - tt = ty.TensorType(out.shape, out.dtype) - # Put shape func on CPU. This also ensures that everything between - # shape_of and shape_func are on CPU. - alloc = self.make_static_allocation(scope, tt, cpu_ctx, i) - alloc = scope.let("shape_func_out_{0}".format(i), alloc) - out_shapes.append(alloc) - - shape_call = self.shape_func( - func, expr.Tuple(shape_func_ins), expr.Tuple(out_shapes), is_inputs - ) - - scope.let("shape_func", shape_call) - return out_shapes - - def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): - """Generate the code for invoking a TVM op with a dynamic shape.""" - out_shapes = self.emit_shape_func(scope, func, new_args) - - storages = [] - func_ctx = self.get_context(func) - for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)): - size = self.compute_storage_in_relay(out_shape, out_type.dtype) - alignment = self.compute_alignment(out_type.dtype) - sto = scope.let( - "storage_{i}".format(i=i), alloc_storage(size, alignment, func_ctx, out_type.dtype) - ) - storages.append(sto) - - outs = [] - sh_ty_storage = zip(out_shapes, out_types, storages) - for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage): - alloc = alloc_tensor(storage, out_shape, out_type.dtype, out_type.shape) - alloc = scope.let("out_{i}".format(i=i), alloc) - outs.append(alloc) - - tuple_outs = expr.Tuple(outs) - invoke = self.invoke_tvm(func, ins, tuple_outs) - scope.let("", invoke) - return to_tuple_type(ret_type, tuple_outs.fields) - - def emit_reshape_tensor(self, scope, func, new_args, ret_type): - if self.is_dynamic(ret_type): - out_shapes = self.emit_shape_func(scope, func, new_args) - shape_expr = out_shapes[0] - else: - # constant output shape - shape = [int(dim) for dim in ret_type.shape] - shape_expr = expr.const(shape, dtype=self.compute_dtype) - return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape) - - def is_dynamic(self, ret_type): - is_dynamic = ty.is_dynamic(ret_type) - # 
TODO(@jroesch): restore this code, more complex then it seems - # for arg in call.args: - # is_dynamic = is_dynamic or arg.checked_type.is_dynamic() - return is_dynamic - - def visit_call(self, call): - if is_primitive(call): - # Because we are in ANF we do not need to visit the arguments. - scope = self.current_scope() - new_args = [self.visit(arg) for arg in call.args] - - ins = expr.Tuple(new_args) - ret_type = call.checked_type - out_types = flatten_tuple_type(ret_type) - - if is_reshape_only(call.op): - # Handle fused op that only contains reshape op - return self.emit_reshape_tensor(scope, call.op, new_args, ret_type) - - if is_device_copy(call.op): - # Handle device copy op - if isinstance(call.op, Function): - attr = call.op.body.attrs - else: - attr = call.attr - return self.device_copy( - new_args[0], TVMContext(attr.src_dev_type, 0), TVMContext(attr.dst_dev_type, 0) - ) - - if self.is_dynamic(ret_type): - # Handle dynamic case. - return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type) - - # Handle static case. - outs = [] - for i, out_ty in enumerate(out_types): - ctx = self.get_context(call) - assert isinstance(ctx, TVMContext) - out = self.make_static_allocation(scope, out_ty, ctx, i) - outs.append(out) - - output = expr.Tuple(outs) - invoke = self.invoke_tvm(call.op, ins, output) - scope.let("", invoke) - return to_tuple_type(ret_type, output.fields) - return super().visit_call(call) - - -def mk_analysis_annotator(results): - """Pretty print the annotated relay program with device info""" - - def _annotator(exp): - if exp in results: - val = results[exp] - assert len(val) == 2 - ctx = TVMContext(val[0].value, val[1].value) - return f"<{ctx}>" - else: - return "" - - return _annotator - - -@module_pass(opt_level=0) -class ManifestAlloc: - """The explicit pass wrapper around ManifestAlloc.""" - - # TODO(zhiics, jroesch) Port this pass to C++. - def __init__(self, target_host, targets): - self.target_host = target_host - self.targets = targets - - def transform_module(self, mod, _): - """Invokes the pass""" - # TODO(@jroesch): Is there a way to do one shot initialization? - # can we have def pass_init? - mod.import_from_std("core.rly") - mod = InferType()(mod) - - assert isinstance(self.targets, (dict, container.Map)) - if len(self.targets) > 1: - pass_ctx = PassContext.current() - if "relay.fallback_device_type" in pass_ctx.config: - fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"]) - else: - fallback_ctx = cpu(0) - ca = context_analysis(mod, TVMContext(fallback_ctx.device_type, 0)) - else: - if isinstance(self.targets, dict): - dev = list(self.targets.keys())[0] - else: - dev, _ = self.targets.items()[0] - ca = context_analysis(mod, nd.context(dev.value)) - - # The following code can be used for debugging the module after - # annotation. 
- # print(mod.astext(show_meta_data=False, annotate=mk_analysis_annotator(ca))) - - gv_funcs = mod.functions - for gv, f in gv_funcs.items(): - ea = ManifestAllocPass(self.target_host, ca) - f = ea.visit(f) - mod.update_func(gv, f) - return mod - - -register_func("relay.transform.ManifestAlloc", ManifestAlloc) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 8fbe31edce3d..2cc387221f74 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -58,12 +58,6 @@ namespace transform { Pass LambdaLift(); Pass InlinePrimitives(); -Pass ManifestAlloc(Target target_host, vm::TargetsMap targets) { - auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); - ICHECK(f != nullptr) << "unable to load allocation manifestation pass"; - return (*f)(target_host, targets); -} - Pass MemoryPlan() { auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan"); ICHECK(f != nullptr) << "unable to load the memory planning pass"; diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc new file mode 100644 index 000000000000..360778e1723b --- /dev/null +++ b/src/relay/transforms/memory_alloc.cc @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/memory_alloc.cc + * \brief A pass for manifesting explicit memory allocations. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../backend/compile_engine.h" +#include "let_list.h" +#include "pattern_utils.h" + +using namespace tvm::runtime; + +namespace tvm { +namespace relay { + +extern Expr ToTupleType(const Type& ty, const std::vector& exprs); +extern std::vector FromTupleType(const Type& type, const Expr& expr); +extern std::vector FlattenTupleType(const Type& type); + +using AnalysisResultMap = + std::unordered_map; + +inline Constant MakeConstant(const std::vector& value) { + return MakeConstantTensor(DataType::Int(64), {static_cast(value.size())}, value); +} + +inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, + Array assert_shape) { + auto f = runtime::Registry::Get("relay.op.memory._make.alloc_tensor"); + CHECK(f != nullptr) << "unable to find alloc_tensor op"; + auto offset = MakeConstantScalar(DataType::Int(64), 0); + return (*f)(storage, offset, shape, dtype, assert_shape); +} + +// A pass to check if the fused op contains only reshape ops. 
+class CheckReshapeOnly : public ExprVisitor { + public: + CheckReshapeOnly() + : reshape_(Op::Get("reshape")), + contr_reshape_(Op::Get("contrib_reverse_reshape")), + dyn_reshape_(Op::Get("dyn.reshape")) {} + + void VisitExpr_(const CallNode* cn) final { + if (!reshape_only) return; + if (cn->op != reshape_ && cn->op != contr_reshape_ && cn->op != dyn_reshape_) { + reshape_only = false; + } + for (auto arg : cn->args) ExprVisitor::VisitExpr(arg); + } + + void VisitExpr_(const VarNode* vn) final { + if (!vn->checked_type_->IsInstance()) { + reshape_only = false; + } + } + + const Op& reshape_; + const Op& contr_reshape_; + const Op& dyn_reshape_; + bool reshape_only{true}; +}; + +// Check if the primitive function contains only reshape ops. +bool IsReshapeOnly(const Expr& expr) { + auto check = CheckReshapeOnly(); + check.VisitExpr(expr); + return check.reshape_only; +} + +class DialectRewriter : public ExprMutator { + public: + DialectRewriter(const Target& target_host, const AnalysisResultMap& context_analysis_map) + : target_host_(target_host), + context_analysis_map_(context_analysis_map), + device_copy_(runtime::Registry::Get("relay.op._make.device_copy")), + invoke_tvm_(runtime::Registry::Get("relay.op.vm.invoke_tvm_op")), + alloc_storage_(runtime::Registry::Get("relay.op.memory._make.alloc_storage")), + shape_func_(runtime::Registry::Get("relay.op.vm.shape_func")), + shape_of_(runtime::Registry::Get("relay.op.vm.shape_of")), + reshape_tensor_(runtime::Registry::Get("relay.op.vm.reshape_tensor")), + prod_(runtime::Registry::Get("relay.op._make.prod")), + divide_(runtime::Registry::Get("relay.op._make.divide")), + add_(runtime::Registry::Get("relay.op._make.add")), + multiply_(runtime::Registry::Get("relay.op._make.multiply")) {} + + // Get the context of an expression. + TVMContext GetContext(const Expr& expr) const { + auto it = context_analysis_map_.find(expr); + CHECK(it != context_analysis_map_.end()) << "Cannot find expr in the context analysis map:\n" + << AsText(expr, false); + return it->second; + } + + Function Rewrite(const Function& expr) { + auto ret = ExprMutator::Mutate(expr); + return Downcast(ret); + } + + Expr VisitExpr_(const TupleNode* tn) final { + LetList& scope = scopes_.back(); + Array new_fields; + for (auto field : tn->fields) { + auto new_field = ExprMutator::Mutate(field); + if (new_field->IsInstance()) { + Var const_var("const", Type(nullptr)); + new_field = scope.Push(const_var, new_field); + } + new_fields.push_back(new_field); + } + return Tuple(new_fields); + } + + Expr VisitExpr_(const LetNode* ln) final { + scopes_.emplace_back(); + + const LetNode* let = ln; + Expr body; + while (let) { + auto new_value = ExprMutator::Mutate(let->value); + scopes_.back().Push(let->var, new_value); + body = let->body; + let = body.as(); + } + + CHECK(body.defined()); + auto new_body = ExprMutator::Mutate(body); + auto ret = scopes_.back().Get(new_body); + scopes_.pop_back(); + return ret; + } + + Expr VisitExpr_(const CallNode* cn) final { + if (IsPrimitive(cn)) { + // Because we are in ANF we do not need to visit the arguments. 
+ LetList& scope = scopes_.back(); + std::vector new_args; + for (const auto& it : cn->args) { + new_args.push_back(ExprMutator::Mutate(it)); + } + + Tuple ins(new_args); + Type ret_type = cn->checked_type_; + std::vector out_types = FlattenTupleType(ret_type); + + // Handle fused op that only contains reshape op + if (IsReshapeOnly(cn->op)) { + Function func = Downcast(cn->op); + return EmitReshapeTensor(&scope, func, new_args, ret_type); + } + + // Handle device copy op + if (IsDeviceCopy(cn->op)) { + Attrs attr; + if (const auto* fn = cn->op.as()) { + const auto* copy_call = fn->body.as(); + CHECK(copy_call); + attr = copy_call->attrs; + } else { + attr = cn->attrs; + } + const DeviceCopyAttrs* copy_attr = attr.as(); + CHECK(copy_attr); + return DeviceCopy(new_args[0], copy_attr->src_dev_type, copy_attr->dst_dev_type); + } else if (IsDynamic(ret_type)) { + Function func = Downcast(cn->op); + return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type); + } else { + // Handle the static case + Array outs; + for (size_t i = 0; i < out_types.size(); ++i) { + TVMContext ctx = GetContext(GetRef(cn)); + auto out = MakeStaticAllocation(&scope, out_types[i], ctx, std::to_string(i)); + outs.push_back(out); + } + Tuple output(outs); + Expr invoke = (*invoke_tvm_)(cn->op, ins, output); + scope.Push(invoke); + return ToTupleType(ret_type, + std::vector(output->fields.begin(), output->fields.end())); + } + } else { + return ExprMutator::VisitExpr_(cn); + } + } + + private: + // Insert a device copy node. + Expr DeviceCopy(const Expr& inp, int src_ctx, int dst_ctx) { + return ExprMutator::Mutate((*device_copy_)(inp, src_ctx, dst_ctx)); + } + + // Check if a call invokes a primitive function. + bool IsPrimitive(const CallNode* call) const { + if (const auto* fn = call->op.as()) { + return fn->HasNonzeroAttr(attr::kPrimitive); + } + return false; + } + + // Check if the current relay expression is a device copy call. We can simply + // check the body of it if it is a function because the device_copy op is opaque. + bool IsDeviceCopy(const Expr& expr) const { + if (const auto* fn = expr.as()) { + auto body = fn->body; + const CallNode* call = body.as(); + return call && call->op == Op::Get("device_copy"); + } else if (const CallNode* cn = expr.as()) { + return cn->op == Op::Get("device_copy"); + } else { + return false; + } + } + + Expr ComputeAlignment(const DataType& dtype) const { + int64_t align = dtype.bits() / 8 * dtype.lanes(); + if (align < 64) { + align = 64; + } + return MakeConstantScalar(DataType::Int(64), align); + } + + Expr ComputeStorageInRelay(const Expr& shape, const TensorType& type) const { + auto dtype = DataType(type->dtype); + Expr els = (*prod_)(shape, Array(nullptr), false, false); + Expr num = MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes()); + Expr add = (*add_)(num, MakeConstantScalar(DataType::Int(64), 7)); + Expr div = MakeConstantScalar(DataType::Int(64), 8); + Expr ret = (*multiply_)(els, (*divide_)(add, div)); + return std::move(ret); + } + + Expr ComputeStorage(const TensorType& type) { + int64_t size = 1; + for (auto it : type->shape) { + auto val = it.as(); + CHECK(val); + size *= val->value; + } + size *= (type->dtype.bits() * type->dtype.lanes() + 7) / 8; + return std::move(MakeConstantScalar(DataType::Int(64), size)); + } + + // Allocate a tensor with a statically known shape. 
+ Var MakeStaticAllocation(LetList* scope, const TensorType& type, TVMContext ctx, + String name_hint) { + std::vector int_shape; + for (auto it : type->shape) { + const auto* imm = it.as(); + CHECK(imm) << "expect static int shape"; + int_shape.push_back(imm->value); + } + Expr shape = MakeConstant(int_shape); + Expr size = ComputeStorage(type); + Expr alignment = ComputeAlignment(type->dtype); + // Run type inference later to get the correct type. + Var var("storage_" + name_hint, Type(nullptr)); + Expr value = (*alloc_storage_)(size, alignment, ctx, type->dtype); + auto sto = scope->Push(var, value); + + // TODO(@jroesch): There is a bug with typing based on the constant shape. + auto tensor = AllocTensor(sto, shape, type->dtype, type->shape); + Var tensor_var("tensor_" + name_hint, Type(nullptr)); + return scope->Push(tensor_var, tensor); + } + + // Insert the shape function given a primitive function. + Array EmitShapeFunc(LetList* scope, const Function& func, + const std::vector& new_args) { + Array shape_func_ins; + auto engine = CompileEngine::Global(); + CCacheKey key(func, target_host_); + auto cfunc = engine->LowerShapeFunc(key); + auto input_states = cfunc->shape_func_param_states; + + Array is_inputs; + int input_pos = 0; + TVMContext cpu_ctx = default_context_; + CHECK_EQ(new_args.size(), input_states.size()); + for (size_t i = 0; i < new_args.size(); ++i) { + Expr arg = new_args[i]; + Type ty; + if (const auto* vn = arg.as()) { + ty = vn->type_annotation; + } else { + ty = arg->checked_type(); + } + int state = input_states[i]->value; + // Pass Shapes + if (state == 2) { + std::vector exprs = FromTupleType(ty, arg); + for (size_t j = 0; j < exprs.size(); ++j) { + Expr sh_of = ExprMutator::Mutate((*shape_of_)(exprs[j])); + Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr)); + shape_func_ins.push_back(scope->Push(in_shape_var, sh_of)); + input_pos++; + } + is_inputs.push_back(0); + } else if (state == 1) { + auto new_arg = ExprMutator::Mutate(arg); + auto ctx = GetContext(arg); + if (ctx.device_type != cpu_ctx.device_type) { + new_arg = DeviceCopy(new_arg, ctx.device_type, cpu_ctx.device_type); + } + Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr)); + shape_func_ins.push_back(scope->Push(in_shape_var, new_arg)); + input_pos++; + is_inputs.push_back(1); + } else { + // TODO(@jroesch): handle 3rd case + LOG(FATAL) << "unsupported shape function input state"; + } + } + + Array out_shapes; + for (size_t i = 0; i < cfunc->outputs.size(); ++i) { + auto out = cfunc->outputs[i]; + auto tt = TensorType(out->shape, out->dtype); + // Put shape func on CPU. This also ensures that everything between + // shape_of and shape_func are on CPU. + auto alloc = MakeStaticAllocation(scope, tt, cpu_ctx, std::to_string(i)); + Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr)); + alloc = scope->Push(shape_func_out_var, alloc); + out_shapes.push_back(alloc); + } + auto shape_call = (*shape_func_)(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs); + Var shape_func_var("shape_func", Type(nullptr)); + scope->Push(shape_func_var, shape_call); + return out_shapes; + } + + // Generate the code for invoking a TVM op with a dynamic shape. 
+ Expr DynamicInvoke(LetList* scope, const Function& func, const Tuple& ins, + const std::vector& new_args, const std::vector& out_types, + const Type& ret_type) { + auto out_shapes = EmitShapeFunc(scope, func, new_args); + std::vector storages; + auto func_ctx = GetContext(func); + CHECK_EQ(out_shapes.size(), out_types.size()); + for (size_t i = 0; i < out_shapes.size(); ++i) { + auto out_shape = out_shapes[i]; + auto out_type = out_types[i]; + auto size = ComputeStorageInRelay(out_shape, out_type); + auto alignment = ComputeAlignment(out_type->dtype); + Var sto_var("storage_" + std::to_string(i), Type(nullptr)); + auto val = (*alloc_storage_)(size, alignment, func_ctx, out_type->dtype); + storages.push_back(scope->Push(sto_var, val)); + } + + Array outs; + for (size_t i = 0; i < storages.size(); ++i) { + auto out_shape = out_shapes[i]; + auto out_type = out_types[i]; + auto storage = storages[i]; + auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape); + Var out_var("out_" + std::to_string(i), Type(nullptr)); + outs.push_back(scope->Push(out_var, alloc)); + } + + Tuple tuple_outs(outs); + auto invoke = (*invoke_tvm_)(func, ins, tuple_outs); + scope->Push(invoke); + return ToTupleType(ret_type, + std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); + } + + Expr EmitReshapeTensor(LetList* scope, const Function& func, const std::vector& new_args, + const Type& ret_type) { + TensorType ret_ty = Downcast(ret_type); + Expr shape_expr; + if (IsDynamic(ret_type)) { + auto out_shapes = EmitShapeFunc(scope, func, new_args); + shape_expr = out_shapes[0]; + } else { + std::vector shape; + for (const auto& it : ret_ty->shape) { + const auto* imm = it.as(); + CHECK(imm) << "expect static int shape"; + shape.push_back(imm->value); + } + shape_expr = MakeConstant(shape); + } + return (*reshape_tensor_)(new_args[0], shape_expr, ret_ty->shape); + } + + private: + Target target_host_; + AnalysisResultMap context_analysis_map_; + std::vector scopes_; + + // Cache the following ops + const PackedFunc* device_copy_; + const PackedFunc* invoke_tvm_; + const PackedFunc* alloc_storage_; + const PackedFunc* shape_func_; + const PackedFunc* shape_of_; + const PackedFunc* reshape_tensor_; + const PackedFunc* prod_; + const PackedFunc* divide_; + const PackedFunc* add_; + const PackedFunc* multiply_; + + runtime::DataType compute_dtype_ = runtime::DataType::Int(64); + TVMContext default_context_{kDLCPU, 0}; +}; + +namespace transform { + +Pass ManifestAlloc(Target target_host, Map targets) { + return tvm::transform::CreateModulePass( + [=](IRModule mod, const PassContext& pass_ctx) { + DLOG(INFO) << "tvm::relay::transform::ManifestAlloc"; + // We need to mutate module, therefore making a copy of it. 
+ mod.CopyOnWrite(); + mod->ImportFromStd("core.rly"); + mod = relay::transform::InferType()(mod); + + TVMContext fallback_ctx; + if (targets.size() > 1) { + auto pass_ctx = PassContext::Current(); + Optional opt_fallback_dev = + pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); + auto fallback_dev = opt_fallback_dev.value(); + CHECK_GT(fallback_dev->value, 0U); + fallback_ctx.device_type = static_cast(fallback_dev->value); + fallback_ctx.device_id = 0; + } else { + const auto& it = targets.begin(); + fallback_ctx.device_type = static_cast((*it).first->value); + fallback_ctx.device_id = 0; + } + auto ca = ContextAnalysis(mod, fallback_ctx); + + auto glob_funcs = mod->functions; + for (const auto& it : glob_funcs) { + if (auto* func_node = it.second.as()) { + auto func = GetRef(func_node); + auto rewriter = DialectRewriter(target_host, ca); + auto updated_func = rewriter.Rewrite(func); + + mod->Update(it.first, updated_func); + } + } + + mod = relay::transform::InferType()(mod); + return mod; + }, + 0, "ManifestAlloc", {}); +} + +TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc") + .set_body_typed([](Target target_host, Map targets) { + return ManifestAlloc(target_host, targets); + }); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 0b575d120e8f..9d05631a753a 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -54,7 +54,6 @@ def check_result( for kind in ["debug", "vm"]: targets = targets or tvm.testing.enabled_targets() for tgt, ctx in targets: - print(tgt) if disable_targets and tgt in disable_targets: continue if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type): diff --git a/tests/python/relay/test_memory_passes.py b/tests/python/relay/test_memory_passes.py index c960d1f90c37..546aaf51f734 100644 --- a/tests/python/relay/test_memory_passes.py +++ b/tests/python/relay/test_memory_passes.py @@ -18,7 +18,6 @@ from tvm import te import numpy as np from tvm import relay -from tvm.relay import memory_alloc def check_memory_plan(func, check_fn):
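The static-shape size and alignment arithmetic in `ComputeStorage` and `ComputeAlignment` above (and the Relay-level equivalent built by `ComputeStorageInRelay` for dynamic shapes) is easy to check by hand. Below is a stand-alone sketch of the same formulas; the function names are illustrative and not part of the patch.

```cpp
#include <cstdint>
#include <vector>

// Same arithmetic as ComputeStorage: bytes = prod(shape) * ((bits * lanes + 7) / 8).
int64_t StorageBytes(const std::vector<int64_t>& shape, int bits, int lanes) {
  int64_t size = 1;
  for (int64_t dim : shape) size *= dim;
  return size * ((static_cast<int64_t>(bits) * lanes + 7) / 8);
}

// Same arithmetic as ComputeAlignment: (bits / 8) * lanes, but never below the
// 64-byte minimum the original Python pass attributes to device_api.h.
int64_t AlignmentBytes(int bits, int lanes) {
  int64_t align = (bits / 8) * lanes;
  return align < 64 ? 64 : align;
}

// Example: a float32 tensor of shape [3, 224, 224] needs
// 3 * 224 * 224 * ((32 + 7) / 8) = 150528 * 4 = 602112 bytes, aligned to 64 bytes.
```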
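For context on how the new C++ entry point is meant to be driven, here is a minimal sketch of constructing and applying the pass from C++, in the spirit of the VM compiler, which now calls `relay::transform::ManifestAlloc` directly instead of looking up the removed Python registration. The `ManifestForCPU` helper and the `"llvm"` targets are illustrative only, and the key/value types of the targets map (device type to `Target`) are an assumption, since the template parameters are not visible in the flattened header diff.

```cpp
#include <dlpack/dlpack.h>
#include <tvm/ir/module.h>
#include <tvm/relay/transform.h>
#include <tvm/target/target.h>

// Hypothetical helper (not part of this patch): manifest allocations for a
// single-device CPU setup.
tvm::IRModule ManifestForCPU(tvm::IRModule mod) {
  tvm::Target target_host("llvm");
  // Assumed layout of the targets map: device type -> Target.
  tvm::Map<tvm::Integer, tvm::Target> targets;
  targets.Set(tvm::Integer(static_cast<int>(kDLCPU)), tvm::Target("llvm"));
  tvm::transform::Pass pass = tvm::relay::transform::ManifestAlloc(target_host, targets);
  // The pass imports core.rly, runs InferType, performs context analysis, and
  // rewrites every function through DialectRewriter before re-inferring types.
  return pass(mod);
}
```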