Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
5 changes: 5 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ using OpNode = tvm::OpNode;

#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName)

namespace op {
namespace annotation {
Expr on_device(Expr data, int device_type);
} // namespace annotation
} // namespace op
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_H_
18 changes: 18 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ using PassContext = tvm::transform::PassContext;
using PassContextNode = tvm::transform::PassContextNode;
using Sequential = tvm::transform::Sequential;

/*
* \brief Function to get the device placement for an op.
*
* \return The context.device_type to be used for op expr
*/
using FTVMGetPlacement = runtime::TypedPackedFunc<int(const Expr& expr)>;

/*
* \brief Create a function pass.
*
Expand Down Expand Up @@ -431,6 +438,17 @@ TVM_DLL Pass SimplifyExpr();
*/
TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);

/*!
* \brief Annotate ops for heterogeneous execution.
*
* \param get_placement a packed function of type int(Expr) which determines the
* placement of each Expr. The returned int is the target device_type to use
* for Expr or -1 for default placement.
*
* \return The pass.
*/
TVM_DLL Pass AnnotateDevicePlacement(FTVMGetPlacement get_placement);

} // namespace transform

/*!
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,18 @@ def Defunctionalization(func, mod):
return _ffi_api.Defunctionalization(func, mod)


def AnnotateDevicePlacement(get_placement):
"""
Annotate a module with compiler_begin and compiler_end for partitioning and
heterogeneous execution.

Returns
-------
None
"""
return _ffi_api.AnnotateDevicePlacement(get_placement)


def to_cps(func, mod=None):
"""
Turn expression into CPS expression.
Expand Down
18 changes: 12 additions & 6 deletions src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,22 @@
namespace tvm {
namespace relay {

namespace op {
namespace annotation {
Expr on_device(Expr data, int device_type) {
auto attrs = make_object<OnDeviceAttrs>();
attrs->device_type = device_type;
static const Op& op = Op::Get("on_device");
return Call(op, {data}, Attrs(attrs), {});
}
} // namespace annotation
} // namespace op

// relay.annotation.on_device
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);

TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device")
.set_body_typed([](Expr data, int device_type) {
auto attrs = make_object<OnDeviceAttrs>();
attrs->device_type = device_type;
static const Op& op = Op::Get("on_device");
return Call(op, {data}, Attrs(attrs), {});
});
.set_body_typed(op::annotation::on_device);

RELAY_REGISTER_OP("on_device")
.describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE)
Expand Down
79 changes: 79 additions & 0 deletions src/relay/transforms/annotate_device_placement.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file annotate_device_placement.cc
* \brief Annotate Expr with on_device indicating the device_type to use per op.
* Once the ops are annotated running the RewriteAnnotatedOps pass inserts device_copy ops
* to copy tensors to the correct device.
*/

#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {

class DeviceAnnotator : public MixedModeMutator {
public:
explicit DeviceAnnotator(IRModule module, transform::FTVMGetPlacement get_placement)
: module_(module), get_placement_(get_placement) {}

private:
IRModule module_;
transform::FTVMGetPlacement get_placement_;

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Expr rc = post;
const CallNode* call_node = post.as<CallNode>();
if (call_node->op.as<OpNode>()) {
int device_type = get_placement_(GetRef<Expr>(call_node));
if (device_type > 0) {
rc = relay::op::annotation::on_device(post, device_type);
}
}
return rc;
}

Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) override { return post; }
Expr Rewrite_(const TupleNode* pre, const Expr& post) override { return post; }
};

Expr AnnotateDevicePlacement(const Expr& expr, const IRModule& mod,
transform::FTVMGetPlacement get_placement) {
return DeviceAnnotator(mod, get_placement).Mutate(expr);
}

namespace transform {

Pass AnnotateDevicePlacement(FTVMGetPlacement get_placement) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(AnnotateDevicePlacement(f, m, get_placement));
};
return CreateFunctionPass(pass_func, 2, "AnnotateDevicePlacement", {});
}

TVM_REGISTER_GLOBAL("relay._transform.AnnotateDevicePlacement")
.set_body_typed(AnnotateDevicePlacement);

} // namespace transform
} // namespace relay
} // namespace tvm
104 changes: 104 additions & 0 deletions tests/python/relay/test_pass_annotate_device_placement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit test for annotating device placement."""
import os
import sys
import numpy as np
import pytest

import tvm
import tvm.relay.testing
import tvm.relay.transform as transform
from tvm import relay
from tvm import runtime
from tvm.contrib import utils
from tvm import relay, tir, autotvm
from tvm.relay import transform
from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple
from tvm.ir import Op

# a b a b
# \/ \/
# add add
# \ /
# \ /
# mul
# / \
# c / c |
# \/ \/
# mul mul
# \ /
# \ /
# add


def get_expected_model(cpu_ctx, dev_ctx):
a = relay.var("a", shape=(2, 3))
b = relay.var("b", shape=(2, 3))
c = relay.var("c", shape=(2, 3))
add1 = relay.add(a, b)
add2 = relay.add(a, b)
mul1 = relay.annotation.on_device(relay.multiply(add1, add2), dev_ctx)
mul2 = relay.annotation.on_device(relay.multiply(mul1, c), dev_ctx)
mul3 = relay.annotation.on_device(relay.multiply(mul1, c), dev_ctx)
add3 = relay.add(mul2, mul3)
func = relay.Function([a, b, c], add3)

mod = tvm.IRModule()
mod["main"] = func
mod = relay.transform.InferType()(mod)

return mod


def get_annotated_model(cpu_ctx, dev_ctx):
a = relay.var("a", shape=(2, 3))
b = relay.var("b", shape=(2, 3))
c = relay.var("c", shape=(2, 3))
add1 = relay.add(a, b)
add2 = relay.add(a, b)
mul1 = relay.multiply(add1, add2)
mul2 = relay.multiply(mul1, c)
mul3 = relay.multiply(mul1, c)
add3 = relay.add(mul2, mul3)
func = relay.Function([a, b, c], add3)

mod = tvm.IRModule()
mod["main"] = func

def get_placement(expr):
"""This method is called for each Call node in the graph. Return the targeted
compiler for each Op or "default"
"""
target_ops = ["multiply"]
placement = -1
if isinstance(expr, Call):
if isinstance(expr.op, Op):
if expr.op.name in target_ops:
placement = dev_ctx.device_type
return placement

mod = relay.transform.AnnotateDevicePlacement(get_placement)(mod)
return mod


def test_device_placement():
ctx1 = tvm.context("cpu")
ctx2 = tvm.context("llvm")
mod = get_annotated_model(ctx1, ctx2)
expected_mod = get_expected_model(ctx1, ctx2)
assert tvm.ir.structural_equal(mod["main"], expected_mod["main"], map_free_vars=True)