Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ class RelayExprNode : public BaseExprNode {
*
* For expressions that have the function type, the virtual device describes where the result of
* the call to the function or closure is stored (instead of where the function itself is stored).
* For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where
* the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual
* device of body. For more details, see the documentation in
* src/relay/transforms/device_planner.cc.
*
* The VirtualDevice's Target field describes how the body of the function should be compiled.
*
* Set to VirtualDevice::FullyUnconstrained by default.
Expand All @@ -190,6 +195,13 @@ class RelayExprNode : public BaseExprNode {
/*!
* \return The virtual device (VirtualDevice).
* If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained().
* Note that for function types, the virtual device is the device where the result of a
* call to the function is stored, not where the function itself lives.
* For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where
* the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual
* device of body.
*
* See the documentation of the virtual_device_ field (above) for more details.
*/
VirtualDevice virtual_device() const;

Expand Down
10 changes: 0 additions & 10 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,6 @@ constexpr const char* kGlobalSymbol = "global_symbol";
*/
constexpr const char* kParamVirtualDevice = "param_virtual_devices";

/*!
* \brief The \p VirtualDevice which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: VirtualDevice
*/
constexpr const char* kResultVirtualDevice = "result_virtual_device";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
8 changes: 8 additions & 0 deletions include/tvm/target/virtual_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ using MemoryScope = String;
*
* These operations are needed during device planning.
*/

class VirtualDeviceNode : public AttrsNode<VirtualDeviceNode> {
private:
/*!
Expand Down Expand Up @@ -361,6 +362,13 @@ class VirtualDeviceCache {
std::unordered_set<VirtualDevice, StructuralHash, StructuralEqual> cache_;
};

/*! brief The attribute key for the virtual device. This key will be promoted to first class on
* functions. For use in the parser and printer only.
*
* Type: VirtualDevice
*/
constexpr const char* kVirtualDevice = "result_virtual_device";

} // namespace tvm

#endif // TVM_TARGET_VIRTUAL_DEVICE_H_
15 changes: 14 additions & 1 deletion src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/runtime/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/virtual_device.h>

#include <fstream>

Expand Down Expand Up @@ -1137,7 +1138,19 @@ class Parser {

// TODO(@jroesch): attributes should never be null, they should always be empty.
if (raw_attrs.size()) {
return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
// Promote kVirtualDevice to first-class
if (raw_attrs.count(kVirtualDevice)) {
ObjectRef vid = raw_attrs.at(kVirtualDevice);
ICHECK(vid.as<VirtualDeviceNode>())
<< "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
<< vid->GetTypeKey();
raw_attrs.erase(kVirtualDevice);
Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
func->virtual_device_ = vid;
return func;
} else {
return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
}
} else {
return relay::Function(params, body, ret_type, generics, tvm::DictAttrs());
}
Expand Down
12 changes: 7 additions & 5 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,16 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
params.push_back(d);
}
if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) {
Doc vid_doc;
vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device());
params.push_back(vid_doc);
}
doc << Doc::Concat(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}

doc << PrintBody(fn->body);
return doc;
}
Expand Down Expand Up @@ -515,11 +521,7 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
for (const Expr& arg : op->args) {
args.push_back(Print(arg));
}
#if TVM_LOG_DEBUG
for (const Type& type_arg : op->type_args) {
args.push_back(Print(type_arg));
}
#endif

for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d);
}
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/graph_plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
smap.Set(GetRef<Expr>(kv.first), storage_info);
}
// Either all or none of the nodes should be annotated.
VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes
<< std::endl;
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
<< "expressions are assigned with virtual device types. Either all "
Expand Down
11 changes: 6 additions & 5 deletions src/relay/op/memory/on_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {

Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
VirtualDevice result_virtual_device) {
return WithAttrs(std::move(function),
{{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)},
{tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}});
auto func = WithAttr(
WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)),
tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices));
VLOG(1) << "Annotated func: " << PrettyPrint(func);
return func;
}

TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);
Expand All @@ -166,8 +168,7 @@ Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_vir
}

VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
auto opt_virtual_device = function_node->GetAttr<VirtualDevice>(tvm::attr::kResultVirtualDevice);
return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained());
return function_node->virtual_device();
}

VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/op/annotation/test_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_function_on_device():
assert len(func.attrs["param_virtual_devices"]) == 2
assert func.attrs["param_virtual_devices"][0].device_type_int == 1 # ie kDLCPU
assert func.attrs["param_virtual_devices"][1].device_type_int == 2 # ie kDLCUDA
assert func.attrs["result_virtual_device"].device_type_int == 2 # ie KDLCUDA
assert func.virtual_device_.device_type_int == 2 # ie KDLCUDA


if __name__ == "__main__":
Expand Down