diff --git a/3rdparty/vta-hw b/3rdparty/vta-hw index 57db5a718c74..87ce9acfae55 160000 --- a/3rdparty/vta-hw +++ b/3rdparty/vta-hw @@ -1 +1 @@ -Subproject commit 57db5a718c74a788c98120ebbe1230797be698c8 +Subproject commit 87ce9acfae550d1a487746e9d06c2e250076e54c diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 12845158a22f..2fef1c5a0306 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -36,6 +36,11 @@ using OpNode = tvm::OpNode; #define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName) +namespace op { +namespace annotation { +Expr on_device(Expr data, int device_type); +} // namespace annotation +} // namespace op } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 123b7e395faa..313bdc7ea14b 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -47,6 +47,13 @@ using PassContext = tvm::transform::PassContext; using PassContextNode = tvm::transform::PassContextNode; using Sequential = tvm::transform::Sequential; +/* + * \brief Function to get the device placement for an op. + * + * \return The context.device_type to be used for op expr + */ +using FTVMGetPlacement = runtime::TypedPackedFunc; + /* * \brief Create a function pass. * @@ -431,6 +438,17 @@ TVM_DLL Pass SimplifyExpr(); */ TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); +/*! + * \brief Annotate ops for heterogeneous execution. + * + * \param get_placement a packed function of type int(Expr) which determines the + * placement of each Expr. The returned int is the target device_type to use + * for Expr or -1 for default placement. + * + * \return The pass. + */ +TVM_DLL Pass AnnotateDevicePlacement(FTVMGetPlacement get_placement); + } // namespace transform /*! diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index c6df8c1e6ea2..8b945ce57cf3 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -817,6 +817,18 @@ def Defunctionalization(func, mod): return _ffi_api.Defunctionalization(func, mod) +def AnnotateDevicePlacement(get_placement): + """ + Annotate a module with compiler_begin and compiler_end for partitioning and + heterogeneous execution. + + Returns + ------- + None + """ + return _ffi_api.AnnotateDevicePlacement(get_placement) + + def to_cps(func, mod=None): """ Turn expression into CPS expression. diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 60a2e95cdcf7..adeba51ebb85 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -36,16 +36,22 @@ namespace tvm { namespace relay { +namespace op { +namespace annotation { +Expr on_device(Expr data, int device_type) { + auto attrs = make_object(); + attrs->device_type = device_type; + static const Op& op = Op::Get("on_device"); + return Call(op, {data}, Attrs(attrs), {}); +} +} // namespace annotation +} // namespace op + // relay.annotation.on_device TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") - .set_body_typed([](Expr data, int device_type) { - auto attrs = make_object(); - attrs->device_type = device_type; - static const Op& op = Op::Get("on_device"); - return Call(op, {data}, Attrs(attrs), {}); - }); + .set_body_typed(op::annotation::on_device); RELAY_REGISTER_OP("on_device") .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) diff --git a/src/relay/transforms/annotate_device_placement.cc b/src/relay/transforms/annotate_device_placement.cc new file mode 100644 index 000000000000..c5152503ab82 --- /dev/null +++ b/src/relay/transforms/annotate_device_placement.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file annotate_device_placement.cc + * \brief Annotate Expr with on_device indicating the device_type to use per op. + * Once the ops are annotated running the RewriteAnnotatedOps pass inserts device_copy ops + * to copy tensors to the correct device. + */ + +#include +#include +#include + +namespace tvm { +namespace relay { + +class DeviceAnnotator : public MixedModeMutator { + public: + explicit DeviceAnnotator(IRModule module, transform::FTVMGetPlacement get_placement) + : module_(module), get_placement_(get_placement) {} + + private: + IRModule module_; + transform::FTVMGetPlacement get_placement_; + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + Expr rc = post; + const CallNode* call_node = post.as(); + if (call_node->op.as()) { + int device_type = get_placement_(GetRef(call_node)); + if (device_type > 0) { + rc = relay::op::annotation::on_device(post, device_type); + } + } + return rc; + } + + Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) override { return post; } + Expr Rewrite_(const TupleNode* pre, const Expr& post) override { return post; } +}; + +Expr AnnotateDevicePlacement(const Expr& expr, const IRModule& mod, + transform::FTVMGetPlacement get_placement) { + return DeviceAnnotator(mod, get_placement).Mutate(expr); +} + +namespace transform { + +Pass AnnotateDevicePlacement(FTVMGetPlacement get_placement) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(AnnotateDevicePlacement(f, m, get_placement)); + }; + return CreateFunctionPass(pass_func, 2, "AnnotateDevicePlacement", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.AnnotateDevicePlacement") + .set_body_typed(AnnotateDevicePlacement); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_annotate_device_placement.py b/tests/python/relay/test_pass_annotate_device_placement.py new file mode 100644 index 000000000000..a07a11fad84e --- /dev/null +++ b/tests/python/relay/test_pass_annotate_device_placement.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit test for annotating device placement.""" +import os +import sys +import numpy as np +import pytest + +import tvm +import tvm.relay.testing +import tvm.relay.transform as transform +from tvm import relay +from tvm import runtime +from tvm.contrib import utils +from tvm import relay, tir, autotvm +from tvm.relay import transform +from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple +from tvm.ir import Op + +# a b a b +# \/ \/ +# add add +# \ / +# \ / +# mul +# / \ +# c / c | +# \/ \/ +# mul mul +# \ / +# \ / +# add + + +def get_expected_model(cpu_ctx, dev_ctx): + a = relay.var("a", shape=(2, 3)) + b = relay.var("b", shape=(2, 3)) + c = relay.var("c", shape=(2, 3)) + add1 = relay.add(a, b) + add2 = relay.add(a, b) + mul1 = relay.annotation.on_device(relay.multiply(add1, add2), dev_ctx) + mul2 = relay.annotation.on_device(relay.multiply(mul1, c), dev_ctx) + mul3 = relay.annotation.on_device(relay.multiply(mul1, c), dev_ctx) + add3 = relay.add(mul2, mul3) + func = relay.Function([a, b, c], add3) + + mod = tvm.IRModule() + mod["main"] = func + mod = relay.transform.InferType()(mod) + + return mod + + +def get_annotated_model(cpu_ctx, dev_ctx): + a = relay.var("a", shape=(2, 3)) + b = relay.var("b", shape=(2, 3)) + c = relay.var("c", shape=(2, 3)) + add1 = relay.add(a, b) + add2 = relay.add(a, b) + mul1 = relay.multiply(add1, add2) + mul2 = relay.multiply(mul1, c) + mul3 = relay.multiply(mul1, c) + add3 = relay.add(mul2, mul3) + func = relay.Function([a, b, c], add3) + + mod = tvm.IRModule() + mod["main"] = func + + def get_placement(expr): + """This method is called for each Call node in the graph. Return the targeted + compiler for each Op or "default" + """ + target_ops = ["multiply"] + placement = -1 + if isinstance(expr, Call): + if isinstance(expr.op, Op): + if expr.op.name in target_ops: + placement = dev_ctx.device_type + return placement + + mod = relay.transform.AnnotateDevicePlacement(get_placement)(mod) + return mod + + +def test_device_placement(): + ctx1 = tvm.context("cpu") + ctx2 = tvm.context("llvm") + mod = get_annotated_model(ctx1, ctx2) + expected_mod = get_expected_model(ctx1, ctx2) + assert tvm.ir.structural_equal(mod["main"], expected_mod["main"], map_free_vars=True)