From 3e4377834883c29353d070ad5a8b054394791e96 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 27 Aug 2023 17:52:42 +0000 Subject: [PATCH] [Runtime] Utils to Stringify Device There exist some basic functionality to convert Device and DLDeviceType to std::string, but they are not following the common naming convention in TVM, and thus less discoverable. This commit makes changes accordingly: - `runtime::DeviceName` to `runtime::DLDeviceType2Str` - move declaration of `operator << (std::ostream&, Device)` from `runtime/device_api.h` to `runtime/packed_func.h` --- include/tvm/runtime/data_type.h | 1 + include/tvm/runtime/device_api.h | 50 +--------------------- include/tvm/runtime/packed_func.h | 52 +++++++++++++++++++++++ include/tvm/tir/op.h | 1 + src/runtime/c_runtime_api.cc | 2 +- src/runtime/contrib/papi/papi.cc | 7 +-- src/runtime/hexagon/hexagon_device_api.cc | 2 +- src/runtime/profiling.cc | 6 +-- src/runtime/rpc/rpc_module.cc | 1 + src/runtime/vm/memory_manager.cc | 15 +++---- src/tir/transforms/lower_tvm_builtin.cc | 4 +- 11 files changed, 73 insertions(+), 68 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 9fb113f56b2c..ac7e879a644d 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -339,6 +339,7 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); } + throw; } inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 654018565716..cb0eb7c21f11 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -245,54 +245,6 @@ class TVM_DLL DeviceAPI { constexpr int kRPCSessMask = 128; static_assert(kRPCSessMask >= TVMDeviceExtType_End); -/*! - * \brief The name of Device API factory. - * \param type The device type. - * \return the device name. - */ -inline const char* DeviceName(int type) { - switch (type) { - case kDLCPU: - return "cpu"; - case kDLCUDA: - return "cuda"; - case kDLCUDAHost: - return "cuda_host"; - case kDLCUDAManaged: - return "cuda_managed"; - case kDLOpenCL: - return "opencl"; - case kDLSDAccel: - return "sdaccel"; - case kDLAOCL: - return "aocl"; - case kDLVulkan: - return "vulkan"; - case kDLMetal: - return "metal"; - case kDLVPI: - return "vpi"; - case kDLROCM: - return "rocm"; - case kDLROCMHost: - return "rocm_host"; - case kDLExtDev: - return "ext_dev"; - case kDLOneAPI: - return "oneapi"; - case kDLWebGPU: - return "webgpu"; - case kDLHexagon: - return "hexagon"; - case kOpenGL: - return "opengl"; - case kDLMicroDev: - return "microdev"; - default: - LOG(FATAL) << "unknown type =" << type; - } -} - /*! * \brief Return true if a Device is owned by an RPC session. */ @@ -324,7 +276,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*) os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-"; dev = tvm::runtime::RemoveRPCSessionMask(dev); } - os << tvm::runtime::DeviceName(static_cast(dev.device_type)) << "(" << dev.device_id << ")"; + os << tvm::runtime::DLDeviceType2Str(static_cast(dev.device_type)) << ":" << dev.device_id; return os; } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 655325ebe190..e63e92835cc5 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -418,6 +418,8 @@ class TVMArgs { */ inline const char* ArgTypeCode2Str(int type_code); +inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) + // macro to check type code. #define TVM_CHECK_TYPE_CODE(CODE, T) \ ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) @@ -1257,6 +1259,56 @@ inline const char* ArgTypeCode2Str(int type_code) { default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); } + throw; +} + +/*! + * \brief The name of DLDeviceType. + * \param type The device type. + * \return the device name. + */ +inline const char* DLDeviceType2Str(int type) { + switch (type) { + case kDLCPU: + return "cpu"; + case kDLCUDA: + return "cuda"; + case kDLCUDAHost: + return "cuda_host"; + case kDLCUDAManaged: + return "cuda_managed"; + case kDLOpenCL: + return "opencl"; + case kDLSDAccel: + return "sdaccel"; + case kDLAOCL: + return "aocl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kDLROCMHost: + return "rocm_host"; + case kDLExtDev: + return "ext_dev"; + case kDLOneAPI: + return "oneapi"; + case kDLWebGPU: + return "webgpu"; + case kDLHexagon: + return "hexagon"; + case kOpenGL: + return "opengl"; + case kDLMicroDev: + return "microdev"; + default: + LOG(FATAL) << "unknown type = " << type; + } + throw; } namespace detail { diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 3d5e589ab4a4..ce4a4d6a2845 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -949,6 +949,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return FloatImm(t, static_cast(value), span); } LOG(FATAL) << "cannot make const for type " << t; + throw; } template <> diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index d7739b7b2225..93ca8a924a98 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -114,7 +114,7 @@ class DeviceAPIManager { if (api_[type] != nullptr) return api_[type]; std::lock_guard lock(mutex_); if (api_[type] != nullptr) return api_[type]; - api_[type] = GetAPI(DeviceName(type), allow_missing); + api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing); return api_[type]; } else { if (rpc_api_ != nullptr) return rpc_api_; diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc index b9ba8f9984e9..4fc29f92ea6a 100644 --- a/src/runtime/contrib/papi/papi.cc +++ b/src/runtime/contrib/papi/papi.cc @@ -73,7 +73,7 @@ int component_for_device(Device dev) { component_name = "rocm"; break; default: - LOG(WARNING) << "PAPI does not support device " << DeviceName(dev.device_type); + LOG(WARNING) << "PAPI does not support device " << DLDeviceType2Str(dev.device_type); return -1; } int cidx = PAPI_get_component_index(component_name.c_str()); @@ -170,8 +170,9 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { default: break; } - LOG(WARNING) << "PAPI could not initialize counters for " << DeviceName(device.device_type) - << ": " << component->disabled_reason << "\n" + LOG(WARNING) << "PAPI could not initialize counters for " + << DLDeviceType2Str(device.device_type) << ": " << component->disabled_reason + << "\n" << help_message; continue; } diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 27e4eb29cc7a..65162e7cc6d0 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -81,7 +81,7 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shap // until the AoT executor's multi-device dispatch code is mature. --cconvey 2022-08-26 CHECK(dev.device_type == kDLHexagon) << "dev.device_type: " << dev.device_type << " DeviceName(" << dev.device_type - << "): " << DeviceName(dev.device_type) << ""; + << "): " << DLDeviceType2Str(dev.device_type) << ""; CHECK(ndim >= 0 && ndim <= 2) << "Hexagon Device API supports only 1d and 2d allocations, but received ndim = " << ndim; diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 2300c1a4e7c8..6a42d840b206 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -93,13 +93,13 @@ std::set seen_devices; std::mutex seen_devices_lock; Timer Timer::Start(Device dev) { - auto f = Registry::Get(std::string("profiling.timer.") + DeviceName(dev.device_type)); + auto f = Registry::Get(std::string("profiling.timer.") + DLDeviceType2Str(dev.device_type)); if (f == nullptr) { { std::lock_guard lock(seen_devices_lock); if (seen_devices.find(dev.device_type) == seen_devices.end()) { LOG(WARNING) - << "No timer implementation for " << DeviceName(dev.device_type) + << "No timer implementation for " << DLDeviceType2Str(dev.device_type) << ", using default timer instead. It may be inaccurate or have extra overhead."; seen_devices.insert(dev.device_type); } @@ -652,7 +652,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con } std::string DeviceString(Device dev) { - return DeviceName(dev.device_type) + std::to_string(dev.device_id); + return DLDeviceType2Str(dev.device_type) + std::to_string(dev.device_id); } Report Profiler::Report() { diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index d8ee2d4c769b..94f6720ca8da 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -193,6 +193,7 @@ class RPCModuleNode final : public ModuleNode { String GetSource(const String& format) final { LOG(FATAL) << "GetSource for rpc Module is not supported"; + throw; } PackedFunc GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat, diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index 2855722a4cf4..cb52a4a4436c 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -119,14 +119,12 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) { std::unique_ptr alloc; switch (type) { case kNaive: { - VLOG(1) << "New naive allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id - << ")"; + VLOG(1) << "New naive allocator for " << dev; alloc.reset(new NaiveAllocator(dev)); break; } case kPooled: { - VLOG(1) << "New pooled allocator for " << DeviceName(dev.device_type) << "(" - << dev.device_id << ")"; + VLOG(1) << "New pooled allocator for " << dev; alloc.reset(new PooledAllocator(dev)); break; } @@ -139,9 +137,9 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) { } auto alloc = m->allocators_.at(dev).get(); if (alloc->type() != type) { - LOG(WARNING) << "The type of existing allocator for " << DeviceName(dev.device_type) << "(" - << dev.device_id << ") is different from the request type (" << alloc->type() - << " vs " << type << ")"; + LOG(WARNING) << "The type of existing allocator for " << dev + << " is different from the request type (" << alloc->type() << " vs " << type + << ")"; } return alloc; } @@ -151,8 +149,7 @@ Allocator* MemoryManager::GetAllocator(Device dev) { std::lock_guard lock(m->mu_); auto it = m->allocators_.find(dev); if (it == m->allocators_.end()) { - LOG(FATAL) << "Allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id - << ") has not been created yet."; + LOG(FATAL) << "Allocator for " << dev << " has not been created yet."; } return it->second.get(); } diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 2868af0b07cf..5afb51934003 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -352,7 +352,7 @@ class BuiltinLower : public StmtExprMutator { << "but was instead the expression " << device_type_ << " with type " << device_type_.value()->GetTypeKey(); - String device_name = runtime::DeviceName(as_int->value); + String device_name = runtime::DLDeviceType2Str(as_int->value); return StringImm("device_api." + device_name + "." + method_name); } @@ -595,7 +595,7 @@ class BuiltinLower : public StmtExprMutator { let->var->type_annotation.as()->element_type.as()->dtype; std::string fdevapi_prefix = "device_api."; - fdevapi_prefix += runtime::DeviceName(device_type_.as()->value); + fdevapi_prefix += runtime::DLDeviceType2Str(device_type_.as()->value); Array args = { GetDeviceMethodName("alloc_nd"),