diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 0e43abb54b93..4a00de802c61 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -178,6 +178,11 @@ class RelayExprNode : public BaseExprNode { * * For expressions that have the function type, the virtual device describes where the result of * the call to the function or closure is stored (instead of where the function itself is stored). + * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where + * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual + * device of body. For more details, see the documentation in + * src/relay/transforms/device_planner.cc. + * * The VirtualDevice's Target field describes how the body of the function should be compiled. * * Set to VirtualDevice::FullyUnconstrained by default. @@ -190,6 +195,13 @@ class RelayExprNode : public BaseExprNode { /*! * \return The virtual device (VirtualDevice). * If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained(). + * Note that for function types, the virtual device is the device where the result of a + * call to the function is stored, not where the function itself lives. + * For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where + * the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual + * device of body. + * + * See the documentation of the virtual_device_ field (above) for more details. */ VirtualDevice virtual_device() const; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 051c05dd3d01..1493544e7324 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -190,26 +190,6 @@ constexpr const char* kTarget = "target"; */ constexpr const char* kGlobalSymbol = "global_symbol"; -/*! - * \brief The \p VirtualDevice which will hold each of the functions parameters. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: Array - */ -constexpr const char* kParamVirtualDevice = "param_virtual_devices"; - -/*! - * \brief The \p VirtualDevice which will hold the function result. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: VirtualDevice - */ -constexpr const char* kResultVirtualDevice = "result_virtual_device"; - } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 2ad27a0d20b0..0019b22f1a8f 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. + VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes + << std::endl; if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes << "expressions are assigned with virtual device types. Either all " diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 48e93ccf654d..f55c59d3bc04 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -144,9 +144,15 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_virtual_devices, VirtualDevice result_virtual_device) { - return WithAttrs(std::move(function), - {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}, - {tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}}); + ICHECK_EQ(param_virtual_devices.size(), function->params.size()) + << "There should be one virtual device per function parameter."; + Array annotated_params; + for (size_t i = 0; i < function->params.size(); i++) { + annotated_params.push_back(WithFields(function->params[i], {}, {}, param_virtual_devices[i])); + } + auto func = WithFields(function, annotated_params, {}, {}, {}, {}, result_virtual_device); + VLOG(1) << "Annotated func: " << PrettyPrint(func); + return func; } TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); @@ -166,22 +172,14 @@ Function MaybeFunctionOnDevice(Function function, Array param_vir } VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) { - auto opt_virtual_device = function_node->GetAttr(tvm::attr::kResultVirtualDevice); - return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained()); + return function_node->virtual_device(); } VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) { ICHECK_LT(i, function_node->params.size()) << "param index " << i << " out of range for function of arity " << function_node->params.size(); - auto opt_array = function_node->GetAttr>(tvm::attr::kParamVirtualDevice); - if (!opt_array) { - // No annotation. - return VirtualDevice::FullyUnconstrained(); - } - ICHECK_EQ(opt_array.value().size(), function_node->params.size()) - << "annotation parameters do not match function arity"; - return opt_array.value()[i]; + return function_node->params[i]->virtual_device(); } } // namespace relay