3 changes: 1 addition & 2 deletions src/relay/backend/aot_executor_codegen.cc
@@ -182,9 +182,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
* \return The corresponding token.
*/
StorageInfo GetStorage(const Expr& expr) {
- auto props = GetOnDeviceProps(expr);
// See through "on_device" calls.
- Expr true_expr = props.body.defined() ? props.body : expr;
+ Expr true_expr = IgnoreOnDevice(expr);
VisitExpr(true_expr);
auto it = storage_device_map_.find(true_expr);
ICHECK(it != storage_device_map_.end());
5 changes: 2 additions & 3 deletions src/relay/backend/graph_plan_memory.cc
@@ -146,10 +146,9 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
* \return The corresponding token.
*/
const std::vector<StorageToken*>& GetToken(const Expr& expr) {
- this->VisitExpr(expr);
// See through on_device calls.
- auto props = GetOnDeviceProps(expr);
- Expr real_expr = props.body.defined() ? props.body : expr;
+ Expr real_expr = IgnoreOnDevice(expr);
+ this->VisitExpr(real_expr);
auto it = token_map_.find(real_expr.get());
ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl
<< PrettyPrint(real_expr);
7 changes: 4 additions & 3 deletions src/relay/backend/vm/compiler.cc
@@ -594,8 +594,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
auto offset_register = last_register_;

// If the shape is constant then we will emit a static tensor allocation
- // instruction.
- auto const_shape = args[2].as<ConstantNode>();
+ // instruction. It may be wrapped by an on_device, but it will be on the host
+ // which is assumed by the alloc_tensor instruction anyway.
+ auto const_shape = AsIgnoringOnDevice<ConstantNode>(args[2]);

if (const_shape) {
NDArray shape = const_shape->data;
@@ -619,7 +620,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
this->VisitExpr(args[0]);
auto size_register = last_register_;

- ICHECK(args[1].as<ConstantNode>());
+ ICHECK(args[1].as<ConstantNode>()); // Always a literal.
NDArray alignment_arr = args[1].as<ConstantNode>()->data;
ICHECK_EQ(alignment_arr->dtype.code, 0U)
<< "The dtype of constant shape must be int32 or int64, but got "
26 changes: 26 additions & 0 deletions src/relay/op/annotation/annotation.h
@@ -85,6 +85,32 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node);
*/
OnDeviceProps GetOnDeviceProps(const Expr& expr);

+ /*!
+  * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns
+  * \p expr directly.
+  */
+ inline Expr IgnoreOnDevice(const Expr& expr) {
+   OnDeviceProps props = GetOnDeviceProps(expr);
+   return props.body.defined() ? props.body : expr;
+ }
+
+ /*!
+  * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through
+  * any "on_device" annotations.
+  */
+ template <typename NodeType>
+ const NodeType* AsIgnoringOnDevice(const Expr& expr) {
+   const auto* node = expr.as<NodeType>();
+   if (node != nullptr) {
+     return node;
+   }
+   OnDeviceProps props = GetOnDeviceProps(expr);
+   if (!props.body.defined()) {
+     return nullptr;
+   }
+   return props.body.as<NodeType>();
+ }
+
/*!
* \brief Returns \p function annotated with "param_device_types" and "result_device_type"
attributes capturing parameter and result device types respectively.
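
For orientation, a minimal usage sketch of the two helpers added above. This is a hypothetical caller written for this review, not part of the change; it assumes TVM's Relay headers, and the names TryLiteralShape and StorageKey are made up for illustration.

#include <tvm/relay/expr.h>

#include "./annotation.h"  // i.e. src/relay/op/annotation/annotation.h; adjust the relative path

namespace tvm {
namespace relay {

// Hypothetical helper: returns the compile-time shape constant of an allocation
// argument, or nullptr when the shape is only known at run time. Matches both
// `Constant(...)` and `on_device(Constant(...), ...)`, mirroring the VM compiler
// and memory.cc call sites in this diff.
inline const ConstantNode* TryLiteralShape(const Expr& shape_arg) {
  return AsIgnoringOnDevice<ConstantNode>(shape_arg);
}

// Hypothetical helper: normalizes an expression before it is used as a key into a
// per-expression map, so that `e` and `on_device(e, ...)` resolve to the same entry,
// as in the AOT and graph memory planners above.
inline Expr StorageKey(const Expr& expr) {
  return IgnoreOnDevice(expr);
}

}  // namespace relay
}  // namespace tvm
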
10 changes: 3 additions & 7 deletions src/relay/op/memory/memory.cc
@@ -101,13 +101,9 @@ Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype,
attrs->assert_shape = assert_shape;
} else {
// Look through any on_device for the shape argument expression.
- Expr literal_shape = shape;
- auto props = GetOnDeviceProps(literal_shape);
- if (props.body.defined()) {
- // See through on_device calls.
- literal_shape = props.body;
- }
- attrs->const_shape = Downcast<Constant>(literal_shape);
+ const auto* constant_node = AsIgnoringOnDevice<ConstantNode>(shape);
+ ICHECK(constant_node);
+ attrs->const_shape = GetRef<Constant>(constant_node);
}
static const Op& op = Op::Get("memory.alloc_tensor");
return Call(op, {storage, offset, shape}, Attrs(attrs), {});
5 changes: 2 additions & 3 deletions src/relay/transforms/pass_utils.h
@@ -118,9 +118,8 @@ inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr&
* is it atomic?
if so, the compute cost of the expression is bounded so it can be copied without graph mode.
*/
- inline bool IsAtomic(const Expr& e) {
- auto props = GetOnDeviceProps(e);
- Expr true_expr = props.body.defined() ? props.body : e;
+ inline bool IsAtomic(const Expr& expr) {
+ Expr true_expr = IgnoreOnDevice(expr);
return true_expr.as<VarNode>() || true_expr.as<OpNode>() || true_expr.as<ConstructorNode>() ||
true_expr.as<GlobalVarNode>() ||
true_expr.as<ConstantNode>(); // Constant is always by reference.
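
IsAtomic gates passes that may duplicate subexpressions: atomic expressions are cheap to repeat by reference, anything else should be bound once and reused. A schematic of that decision follows; ShareOrBind and the bind callback are hypothetical names for illustration only, assuming the pass_utils.h declarations above.

#include <functional>

#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {

// Hypothetical illustration: vars, ops, constructors, global vars and constants are
// atomic, so they can be repeated freely; any other expression is handed to `bind`,
// which would typically introduce a let-binding and return the bound variable.
inline Expr ShareOrBind(const Expr& e, const std::function<Expr(const Expr&)>& bind) {
  return IsAtomic(e) ? e : bind(e);
}

}  // namespace relay
}  // namespace tvm
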
15 changes: 14 additions & 1 deletion tests/python/relay/test_vm.py
@@ -766,6 +766,19 @@ def test_vm_reshape_tensor(target, dev):
check_result(target, dev, [x_np, y_np], x_np.reshape([8, 2, 8]), mod)


+ def test_vm_reshape_and_copy(target, dev):
+     """Make sure the compiler notices the reshape result shape is a literal and can use
+     the immediate-mode alloc_tensor instruction instead of alloc_tensor_reg."""
+     x_np = np.random.uniform(size=(1, 1)).astype("float32")
+     x = relay.var("x", shape=(1, 1), dtype="float32")
+     mod = tvm.IRModule.from_expr(relay.Function([x], relay.copy(relay.reshape(x, [0, 1]))))
+     with tvm.transform.PassContext(opt_level=3):
+         exec = relay.vm.compile(mod, "llvm")
+     assert "alloc_tensor" in exec.bytecode
+     assert not "alloc_tensor_reg" in exec.bytecode
+     check_result(target, dev, [x_np], x_np.reshape([1, 1]), mod)
+
+
def test_vm_reshape_tuple(target, dev, x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
tup = relay.var(
"tup",
@@ -963,4 +976,4 @@ def test_benchmark_end_to_end_rpc():
if __name__ == "__main__":
import sys

- sys.exit(pytest.main(sys.argv))
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))