diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 0e43abb54b93..4a00de802c61 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -178,6 +178,11 @@ class RelayExprNode : public BaseExprNode { * * For expressions that have the function type, the virtual device describes where the result of * the call to the function or closure is stored (instead of where the function itself is stored). + * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where + * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual + * device of body. For more details, see the documentation in + * src/relay/transforms/device_planner.cc. + * * The VirtualDevice's Target field describes how the body of the function should be compiled. * * Set to VirtualDevice::FullyUnconstrained by default. @@ -190,6 +195,13 @@ class RelayExprNode : public BaseExprNode { /*! * \return The virtual device (VirtualDevice). * If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained(). + * Note that for function types, the virtual device is the device where the result of a + * call to the function is stored, not where the function itself lives. + * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where + * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual + * device of body. + * + * See the documentation of the virtual_device_ field (above) for more details. */ VirtualDevice virtual_device() const; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 051c05dd3d01..72dc8a5c9bf9 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -200,16 +200,6 @@ constexpr const char* kGlobalSymbol = "global_symbol"; */ constexpr const char* kParamVirtualDevice = "param_virtual_devices"; -/*! - * \brief The \p VirtualDevice which will hold the function result. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: VirtualDevice - */ -constexpr const char* kResultVirtualDevice = "result_virtual_device"; - } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index 07011ea412a3..4a40777563af 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -162,6 +162,7 @@ using MemoryScope = String; * * These operations are needed during device planning. */ + class VirtualDeviceNode : public AttrsNode { private: /*! @@ -361,6 +362,13 @@ class VirtualDeviceCache { std::unordered_set cache_; }; +/*! brief The attribute key for the virtual device. This key will be promoted to first class on + * functions. For use in the parser and printer only. + * + * Type: VirtualDevice + */ +constexpr const char* kVirtualDevice = "result_virtual_device"; + } // namespace tvm #endif // TVM_TARGET_VIRTUAL_DEVICE_H_ diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 44aeeb3bdee1..b80fc3277697 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include @@ -1137,7 +1138,19 @@ class Parser { // TODO(@jroesch): attributes should never be null, they should always be empty. if (raw_attrs.size()) { - return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + // Promote kVirtualDevice to first-class + if (raw_attrs.count(kVirtualDevice)) { + ObjectRef vid = raw_attrs.at(kVirtualDevice); + ICHECK(vid.as()) + << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got " + << vid->GetTypeKey(); + raw_attrs.erase(kVirtualDevice); + Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + func->virtual_device_ = vid; + return func; + } else { + return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + } } else { return relay::Function(params, body, ret_type, generics, tvm::DictAttrs()); } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index fdc6c37e527a..0ef45d878393 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -445,10 +445,16 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { for (const Doc& d : PrintDictAttrs(fn->attrs)) { params.push_back(d); } + if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) { + Doc vid_doc; + vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device()); + params.push_back(vid_doc); + } doc << Doc::Concat(params) << ") "; if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; } + doc << PrintBody(fn->body); return doc; } @@ -515,11 +521,7 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { for (const Expr& arg : op->args) { args.push_back(Print(arg)); } -#if TVM_LOG_DEBUG - for (const Type& type_arg : op->type_args) { - args.push_back(Print(type_arg)); - } -#endif + for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { args.push_back(d); } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 2ad27a0d20b0..0019b22f1a8f 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. + VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes + << std::endl; if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes << "expressions are assigned with virtual device types. Either all " diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 48e93ccf654d..c66c91ecc739 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -144,9 +144,11 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_virtual_devices, VirtualDevice result_virtual_device) { - return WithAttrs(std::move(function), - {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}, - {tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}}); + auto func = WithAttr( + WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)), + tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)); + VLOG(1) << "Annotated func: " << PrettyPrint(func); + return func; } TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); @@ -166,8 +168,7 @@ Function MaybeFunctionOnDevice(Function function, Array param_vir } VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) { - auto opt_virtual_device = function_node->GetAttr(tvm::attr::kResultVirtualDevice); - return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained()); + return function_node->virtual_device(); } VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) { diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 2352821f7bee..abc458313101 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -70,7 +70,7 @@ def test_function_on_device(): assert len(func.attrs["param_virtual_devices"]) == 2 assert func.attrs["param_virtual_devices"][0].device_type_int == 1 # ie kDLCPU assert func.attrs["param_virtual_devices"][1].device_type_int == 2 # ie kDLCUDA - assert func.attrs["result_virtual_device"].device_type_int == 2 # ie KDLCUDA + assert func.virtual_device_.device_type_int == 2 # ie KDLCUDA if __name__ == "__main__":