From 4b2b70d9cd2e2fff93f2fab3e3f7c06b522a2d84 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 25 Jan 2022 12:04:57 -0800 Subject: [PATCH 1/4] VStore function result virtual devices in virtual_device_ field --- include/tvm/ir/expr.h | 12 ++++++++++++ include/tvm/ir/function.h | 10 ---------- src/printer/relay_text_printer.cc | 1 + src/relay/backend/graph_plan_memory.cc | 2 ++ src/relay/op/memory/on_device.cc | 11 ++++++----- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 0e43abb54b93..4a00de802c61 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -178,6 +178,11 @@ class RelayExprNode : public BaseExprNode { * * For expressions that have the function type, the virtual device describes where the result of * the call to the function or closure is stored (instead of where the function itself is stored). + * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where + * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual + * device of body. For more details, see the documentation in + * src/relay/transforms/device_planner.cc. + * * The VirtualDevice's Target field describes how the body of the function should be compiled. * * Set to VirtualDevice::FullyUnconstrained by default. @@ -190,6 +195,13 @@ class RelayExprNode : public BaseExprNode { /*! * \return The virtual device (VirtualDevice). * If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained(). + * Note that for function types, the virtual device is the device where the result of a + * call to the function is stored, not where the function itself lives. + * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where + * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual + * device of body. + * + * See the documentation of the virtual_device_ field (above) for more details. */ VirtualDevice virtual_device() const; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 051c05dd3d01..72dc8a5c9bf9 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -200,16 +200,6 @@ constexpr const char* kGlobalSymbol = "global_symbol"; */ constexpr const char* kParamVirtualDevice = "param_virtual_devices"; -/*! - * \brief The \p VirtualDevice which will hold the function result. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: VirtualDevice - */ -constexpr const char* kResultVirtualDevice = "result_virtual_device"; - } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index fdc6c37e527a..8dbfb8ce01f4 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -449,6 +449,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; } + doc << "Virtual Device: " << Print(fn->virtual_device()) << " \n"; doc << PrintBody(fn->body); return doc; } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 2ad27a0d20b0..0019b22f1a8f 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. + VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes + << std::endl; if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes << "expressions are assigned with virtual device types. Either all " diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 48e93ccf654d..4536725b2073 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -144,9 +144,11 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_virtual_devices, VirtualDevice result_virtual_device) { - return WithAttrs(std::move(function), - {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}, - {tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}}); + auto func = WithAttrs( + WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)), + {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}}); + VLOG(1) << "Annotated func: " << PrettyPrint(func); + return func; } TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); @@ -166,8 +168,7 @@ Function MaybeFunctionOnDevice(Function function, Array param_vir } VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) { - auto opt_virtual_device = function_node->GetAttr(tvm::attr::kResultVirtualDevice); - return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained()); + return function_node->virtual_device(); } VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) { From 9fa7f838474ba1e7cea6dae5762a962feb376810 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 26 Jan 2022 15:33:17 -0800 Subject: [PATCH 2/4] Don't print virtual device in text printer --- src/printer/relay_text_printer.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 8dbfb8ce01f4..fdc6c37e527a 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -449,7 +449,6 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; } - doc << "Virtual Device: " << Print(fn->virtual_device()) << " \n"; doc << PrintBody(fn->body); return doc; } From 7f9723ac9c37b9ea9797bb1179b42b7b2d60481f Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 26 Jan 2022 16:05:49 -0800 Subject: [PATCH 3/4] Store parameter virtual device in function parameters --- include/tvm/ir/function.h | 10 ---------- src/relay/op/memory/on_device.cc | 11 ++++++++--- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 72dc8a5c9bf9..1493544e7324 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -190,16 +190,6 @@ constexpr const char* kTarget = "target"; */ constexpr const char* kGlobalSymbol = "global_symbol"; -/*! - * \brief The \p VirtualDevice which will hold each of the functions parameters. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: Array - */ -constexpr const char* kParamVirtualDevice = "param_virtual_devices"; - } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 4536725b2073..818d84128b27 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -144,9 +144,14 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_virtual_devices, VirtualDevice result_virtual_device) { - auto func = WithAttrs( - WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)), - {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}}); + ICHECK_EQ(param_virtual_devices.size(), function->params.size()) + << "There should be one virtual device per function parameter."; + Array annotated_params; + for (size_t i = 0; i < function->params.size(); i++) { + annotated_params.push_back(WithFields(function->params[i], {}, {}, param_virtual_devices[i])); + } + auto func = WithFields(function, annotated_params, {}, {}, {}, {}, + result_virtual_device); VLOG(1) << "Annotated func: " << PrettyPrint(func); return func; } From 5d9cd8715577528dab2f9d5c993b49131417b814 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 26 Jan 2022 16:12:58 -0800 Subject: [PATCH 4/4] Update GetFunctionParamVirtualDevice --- src/relay/op/memory/on_device.cc | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 818d84128b27..f55c59d3bc04 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -150,8 +150,7 @@ Function FunctionOnDevice(Function function, Array param_virtual_ for (size_t i = 0; i < function->params.size(); i++) { annotated_params.push_back(WithFields(function->params[i], {}, {}, param_virtual_devices[i])); } - auto func = WithFields(function, annotated_params, {}, {}, {}, {}, - result_virtual_device); + auto func = WithFields(function, annotated_params, {}, {}, {}, {}, result_virtual_device); VLOG(1) << "Annotated func: " << PrettyPrint(func); return func; } @@ -180,14 +179,7 @@ VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, s ICHECK_LT(i, function_node->params.size()) << "param index " << i << " out of range for function of arity " << function_node->params.size(); - auto opt_array = function_node->GetAttr>(tvm::attr::kParamVirtualDevice); - if (!opt_array) { - // No annotation. - return VirtualDevice::FullyUnconstrained(); - } - ICHECK_EQ(opt_array.value().size(), function_node->params.size()) - << "annotation parameters do not match function arity"; - return opt_array.value()[i]; + return function_node->params[i]->virtual_device(); } } // namespace relay