From 5b09b96d9be0a425e87801688ac5096e6a8bd294 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 25 Jan 2022 12:04:57 -0800 Subject: [PATCH 1/9] VStore function result virtual devices in virtual_device_ field --- include/tvm/ir/expr.h | 12 ++++++++++++ include/tvm/ir/function.h | 10 ---------- src/printer/relay_text_printer.cc | 1 + src/relay/backend/graph_plan_memory.cc | 2 ++ src/relay/op/memory/on_device.cc | 11 ++++++----- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 0e43abb54b93..4a00de802c61 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -178,6 +178,11 @@ class RelayExprNode : public BaseExprNode { * * For expressions that have the function type, the virtual device describes where the result of * the call to the function or closure is stored (instead of where the function itself is stored). + * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where + * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual + * device of body. For more details, see the documentation in + * src/relay/transforms/device_planner.cc. + * * The VirtualDevice's Target field describes how the body of the function should be compiled. * * Set to VirtualDevice::FullyUnconstrained by default. @@ -190,6 +195,13 @@ class RelayExprNode : public BaseExprNode { /*! * \return The virtual device (VirtualDevice). * If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained(). + * Note that for function types, the virtual device is the device where the result of a + * call to the function is stored, not where the function itself lives. + * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where + * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual + * device of body. + * + * See the documentation of the virtual_device_ field (above) for more details. */ VirtualDevice virtual_device() const; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 051c05dd3d01..72dc8a5c9bf9 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -200,16 +200,6 @@ constexpr const char* kGlobalSymbol = "global_symbol"; */ constexpr const char* kParamVirtualDevice = "param_virtual_devices"; -/*! - * \brief The \p VirtualDevice which will hold the function result. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: VirtualDevice - */ -constexpr const char* kResultVirtualDevice = "result_virtual_device"; - } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index fdc6c37e527a..8dbfb8ce01f4 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -449,6 +449,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; } + doc << "Virtual Device: " << Print(fn->virtual_device()) << " \n"; doc << PrintBody(fn->body); return doc; } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 2ad27a0d20b0..0019b22f1a8f 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. + VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes + << std::endl; if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes << "expressions are assigned with virtual device types. Either all " diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 48e93ccf654d..4536725b2073 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -144,9 +144,11 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_virtual_devices, VirtualDevice result_virtual_device) { - return WithAttrs(std::move(function), - {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}, - {tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}}); + auto func = WithAttrs( + WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)), + {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}}); + VLOG(1) << "Annotated func: " << PrettyPrint(func); + return func; } TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); @@ -166,8 +168,7 @@ Function MaybeFunctionOnDevice(Function function, Array param_vir } VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) { - auto opt_virtual_device = function_node->GetAttr(tvm::attr::kResultVirtualDevice); - return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained()); + return function_node->virtual_device(); } VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) { From 657ebcdde421cf0ef5bf64d7b76bbd680a7a28d8 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 15 Feb 2022 10:42:01 -0800 Subject: [PATCH 2/9] Address Mark's 'mega nit' --- src/relay/op/memory/on_device.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 4536725b2073..c66c91ecc739 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -144,9 +144,9 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_virtual_devices, VirtualDevice result_virtual_device) { - auto func = WithAttrs( + auto func = WithAttr( WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)), - {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}}); + tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)); VLOG(1) << "Annotated func: " << PrettyPrint(func); return func; } From 4223e6005dabb5b4b21613c4c2c1efe5537855d2 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 15 Feb 2022 11:23:48 -0800 Subject: [PATCH 3/9] Promote function result virtual device to first class --- src/parser/parser.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 44aeeb3bdee1..a2fb6cc69e90 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1137,7 +1137,17 @@ class Parser { // TODO(@jroesch): attributes should never be null, they should always be empty. if (raw_attrs.size()) { - return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + // Promote "result_virtual_device" to first-class + if (raw_attrs.count("result_virtual_device")) { + ObjectRef vid = raw_attrs.at("result_virtual_device"); + // TODO(@electriclilies): check that this is a virtaul device + raw_attrs.erase("result_virtual_device"); + Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + func->virtual_device_ = vid; + return func; + } else { + return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + } } else { return relay::Function(params, body, ret_type, generics, tvm::DictAttrs()); } From 5ae4d7f7272f1026e34f01330460583470087db4 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 15 Feb 2022 11:42:13 -0800 Subject: [PATCH 4/9] Add kVirtualDevice --- src/parser/parser.cc | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index a2fb6cc69e90..186f133f3a6c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1080,6 +1080,16 @@ class Parser { } } + /*! brief The attribute key for the virtual device. This key will be promoted to first class on + * functions and variable bindings. + * + * Type: VirtualDevice + */ + // why can't this be constexpr? + // also where to put me? + // also change to just virtual_device, no result + const char* kVirtualDevice = "result_virtual_device"; + /*! Parse a function definition without a leading keyword or identifier. * * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. @@ -1137,11 +1147,14 @@ class Parser { // TODO(@jroesch): attributes should never be null, they should always be empty. if (raw_attrs.size()) { - // Promote "result_virtual_device" to first-class - if (raw_attrs.count("result_virtual_device")) { - ObjectRef vid = raw_attrs.at("result_virtual_device"); - // TODO(@electriclilies): check that this is a virtaul device - raw_attrs.erase("result_virtual_device"); + // Promote kVirtualDevice to first-class + String vid_key = kVirtualDevice; + if (raw_attrs.count(vid_key)) { + ObjectRef vid = raw_attrs.at(kVirtualDevice); + ICHECK(vid.as()) + << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got " + << vid->GetTypeKey(); + raw_attrs.erase(kVirtualDevice); Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); func->virtual_device_ = vid; return func; From 145ad0437d5e4f402c87966583ed3219b4bef906 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 15 Feb 2022 13:20:32 -0800 Subject: [PATCH 5/9] move kVirtualDevice --- src/parser/parser.cc | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 186f133f3a6c..925edcb68baa 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -230,6 +230,13 @@ GlobalTypeVar AddOrGet(InternTable* table, const std::string& nam } } +/*! brief The attribute key for the virtual device. This key will be promoted to first class on + * functions. + * + * Type: VirtualDevice + */ +constexpr const char* kVirtualDevice = "result_virtual_device"; + /*! \brief The parser class is the main interface to the parser. * the parser is not currently exposed beyond this .cc file. * @@ -1080,16 +1087,6 @@ class Parser { } } - /*! brief The attribute key for the virtual device. This key will be promoted to first class on - * functions and variable bindings. - * - * Type: VirtualDevice - */ - // why can't this be constexpr? - // also where to put me? - // also change to just virtual_device, no result - const char* kVirtualDevice = "result_virtual_device"; - /*! Parse a function definition without a leading keyword or identifier. * * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. From ecf0ee17b9be3918330a6f8e1141d7a2f44b56a1 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 15 Feb 2022 13:37:49 -0800 Subject: [PATCH 6/9] Fix annotation test --- tests/python/relay/op/annotation/test_annotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 2352821f7bee..abc458313101 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -70,7 +70,7 @@ def test_function_on_device(): assert len(func.attrs["param_virtual_devices"]) == 2 assert func.attrs["param_virtual_devices"][0].device_type_int == 1 # ie kDLCPU assert func.attrs["param_virtual_devices"][1].device_type_int == 2 # ie kDLCUDA - assert func.attrs["result_virtual_device"].device_type_int == 2 # ie KDLCUDA + assert func.virtual_device_.device_type_int == 2 # ie KDLCUDA if __name__ == "__main__": From 1a331436a1c93c309b7a5a79bd89674af202c681 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 16 Feb 2022 13:11:07 -0800 Subject: [PATCH 7/9] Progress on parsing & printing --- include/tvm/target/virtual_device.h | 8 ++++++++ src/parser/parser.cc | 11 ++--------- src/printer/relay_text_printer.cc | 7 ++++++- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index 07011ea412a3..4a40777563af 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -162,6 +162,7 @@ using MemoryScope = String; * * These operations are needed during device planning. */ + class VirtualDeviceNode : public AttrsNode { private: /*! @@ -361,6 +362,13 @@ class VirtualDeviceCache { std::unordered_set cache_; }; +/*! brief The attribute key for the virtual device. This key will be promoted to first class on + * functions. For use in the parser and printer only. + * + * Type: VirtualDevice + */ +constexpr const char* kVirtualDevice = "result_virtual_device"; + } // namespace tvm #endif // TVM_TARGET_VIRTUAL_DEVICE_H_ diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 925edcb68baa..b80fc3277697 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include @@ -230,13 +231,6 @@ GlobalTypeVar AddOrGet(InternTable* table, const std::string& nam } } -/*! brief The attribute key for the virtual device. This key will be promoted to first class on - * functions. - * - * Type: VirtualDevice - */ -constexpr const char* kVirtualDevice = "result_virtual_device"; - /*! \brief The parser class is the main interface to the parser. * the parser is not currently exposed beyond this .cc file. * @@ -1145,8 +1139,7 @@ class Parser { // TODO(@jroesch): attributes should never be null, they should always be empty. if (raw_attrs.size()) { // Promote kVirtualDevice to first-class - String vid_key = kVirtualDevice; - if (raw_attrs.count(vid_key)) { + if (raw_attrs.count(kVirtualDevice)) { ObjectRef vid = raw_attrs.at(kVirtualDevice); ICHECK(vid.as()) << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got " diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 8dbfb8ce01f4..af538ab72090 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -445,11 +445,16 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { for (const Doc& d : PrintDictAttrs(fn->attrs)) { params.push_back(d); } + if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) { + Doc vid_doc; + vid_doc << kVirtualDevice << "=" << Print(fn->virtual_device()); + params.push_back(vid_doc); + } doc << Doc::Concat(params) << ") "; if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; } - doc << "Virtual Device: " << Print(fn->virtual_device()) << " \n"; + doc << PrintBody(fn->body); return doc; } From c8a920ee65bf03500f2d5b57c8ed94f3c1cbbdab Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 16 Feb 2022 15:16:55 -0800 Subject: [PATCH 8/9] Fix printing of virtual device attribute --- src/printer/relay_text_printer.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index af538ab72090..0ef45d878393 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -447,7 +447,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { } if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) { Doc vid_doc; - vid_doc << kVirtualDevice << "=" << Print(fn->virtual_device()); + vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device()); params.push_back(vid_doc); } doc << Doc::Concat(params) << ") "; @@ -521,11 +521,7 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { for (const Expr& arg : op->args) { args.push_back(Print(arg)); } -#if TVM_LOG_DEBUG - for (const Type& type_arg : op->type_args) { - args.push_back(Print(type_arg)); - } -#endif + for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { args.push_back(d); } From 47ab243ffcf72214c63f90a2731437dcdce91a8e Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 18 Feb 2022 11:27:34 -0800 Subject: [PATCH 9/9] flake