Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
throw;
}

inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
Expand Down
50 changes: 1 addition & 49 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,54 +245,6 @@ class TVM_DLL DeviceAPI {
constexpr int kRPCSessMask = 128;
static_assert(kRPCSessMask >= TVMDeviceExtType_End);

/*!
* \brief The name of Device API factory.
* \param type The device type.
* \return the device name.
*/
inline const char* DeviceName(int type) {
switch (type) {
case kDLCPU:
return "cpu";
case kDLCUDA:
return "cuda";
case kDLCUDAHost:
return "cuda_host";
case kDLCUDAManaged:
return "cuda_managed";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
return "sdaccel";
case kDLAOCL:
return "aocl";
case kDLVulkan:
return "vulkan";
case kDLMetal:
return "metal";
case kDLVPI:
return "vpi";
case kDLROCM:
return "rocm";
case kDLROCMHost:
return "rocm_host";
case kDLExtDev:
return "ext_dev";
case kDLOneAPI:
return "oneapi";
case kDLWebGPU:
return "webgpu";
case kDLHexagon:
return "hexagon";
case kOpenGL:
return "opengl";
case kDLMicroDev:
return "microdev";
default:
LOG(FATAL) << "unknown type =" << type;
}
}

/*!
* \brief Return true if a Device is owned by an RPC session.
*/
Expand Down Expand Up @@ -324,7 +276,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*)
os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-";
dev = tvm::runtime::RemoveRPCSessionMask(dev);
}
os << tvm::runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" << dev.device_id << ")";
os << tvm::runtime::DLDeviceType2Str(static_cast<int>(dev.device_type)) << ":" << dev.device_id;
return os;
}

Expand Down
52 changes: 52 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ class TVMArgs {
*/
inline const char* ArgTypeCode2Str(int type_code);

inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*)

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE)
Expand Down Expand Up @@ -1257,6 +1259,56 @@ inline const char* ArgTypeCode2Str(int type_code) {
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
throw;
}

/*!
* \brief The name of DLDeviceType.
* \param type The device type.
* \return the device name.
*/
inline const char* DLDeviceType2Str(int type) {
switch (type) {
case kDLCPU:
return "cpu";
case kDLCUDA:
return "cuda";
case kDLCUDAHost:
return "cuda_host";
case kDLCUDAManaged:
return "cuda_managed";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
return "sdaccel";
case kDLAOCL:
return "aocl";
case kDLVulkan:
return "vulkan";
case kDLMetal:
return "metal";
case kDLVPI:
return "vpi";
case kDLROCM:
return "rocm";
case kDLROCMHost:
return "rocm_host";
case kDLExtDev:
return "ext_dev";
case kDLOneAPI:
return "oneapi";
case kDLWebGPU:
return "webgpu";
case kDLHexagon:
return "hexagon";
case kOpenGL:
return "opengl";
case kDLMicroDev:
return "microdev";
default:
LOG(FATAL) << "unknown type = " << type;
}
throw;
}

namespace detail {
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return FloatImm(t, static_cast<double>(value), span);
}
LOG(FATAL) << "cannot make const for type " << t;
throw;
}

template <>
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class DeviceAPIManager {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
api_[type] = GetAPI(DeviceName(type), allow_missing);
api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing);
return api_[type];
} else {
if (rpc_api_ != nullptr) return rpc_api_;
Expand Down
7 changes: 4 additions & 3 deletions src/runtime/contrib/papi/papi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ int component_for_device(Device dev) {
component_name = "rocm";
break;
default:
LOG(WARNING) << "PAPI does not support device " << DeviceName(dev.device_type);
LOG(WARNING) << "PAPI does not support device " << DLDeviceType2Str(dev.device_type);
return -1;
}
int cidx = PAPI_get_component_index(component_name.c_str());
Expand Down Expand Up @@ -170,8 +170,9 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode {
default:
break;
}
LOG(WARNING) << "PAPI could not initialize counters for " << DeviceName(device.device_type)
<< ": " << component->disabled_reason << "\n"
LOG(WARNING) << "PAPI could not initialize counters for "
<< DLDeviceType2Str(device.device_type) << ": " << component->disabled_reason
<< "\n"
<< help_message;
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shap
// until the AoT executor's multi-device dispatch code is mature. --cconvey 2022-08-26
CHECK(dev.device_type == kDLHexagon)
<< "dev.device_type: " << dev.device_type << " DeviceName(" << dev.device_type
<< "): " << DeviceName(dev.device_type) << "";
<< "): " << DLDeviceType2Str(dev.device_type) << "";

CHECK(ndim >= 0 && ndim <= 2)
<< "Hexagon Device API supports only 1d and 2d allocations, but received ndim = " << ndim;
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/profiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ std::set<DLDeviceType> seen_devices;
std::mutex seen_devices_lock;

Timer Timer::Start(Device dev) {
auto f = Registry::Get(std::string("profiling.timer.") + DeviceName(dev.device_type));
auto f = Registry::Get(std::string("profiling.timer.") + DLDeviceType2Str(dev.device_type));
if (f == nullptr) {
{
std::lock_guard<std::mutex> lock(seen_devices_lock);
if (seen_devices.find(dev.device_type) == seen_devices.end()) {
LOG(WARNING)
<< "No timer implementation for " << DeviceName(dev.device_type)
<< "No timer implementation for " << DLDeviceType2Str(dev.device_type)
<< ", using default timer instead. It may be inaccurate or have extra overhead.";
seen_devices.insert(dev.device_type);
}
Expand Down Expand Up @@ -652,7 +652,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con
}

std::string DeviceString(Device dev) {
return DeviceName(dev.device_type) + std::to_string(dev.device_id);
return DLDeviceType2Str(dev.device_type) + std::to_string(dev.device_id);
}

Report Profiler::Report() {
Expand Down
1 change: 1 addition & 0 deletions src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class RPCModuleNode final : public ModuleNode {

String GetSource(const String& format) final {
LOG(FATAL) << "GetSource for rpc Module is not supported";
throw;
}

PackedFunc GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat,
Expand Down
15 changes: 6 additions & 9 deletions src/runtime/vm/memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,12 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
std::unique_ptr<Allocator> alloc;
switch (type) {
case kNaive: {
VLOG(1) << "New naive allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
<< ")";
VLOG(1) << "New naive allocator for " << dev;
alloc.reset(new NaiveAllocator(dev));
break;
}
case kPooled: {
VLOG(1) << "New pooled allocator for " << DeviceName(dev.device_type) << "("
<< dev.device_id << ")";
VLOG(1) << "New pooled allocator for " << dev;
alloc.reset(new PooledAllocator(dev));
break;
}
Expand All @@ -139,9 +137,9 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
}
auto alloc = m->allocators_.at(dev).get();
if (alloc->type() != type) {
LOG(WARNING) << "The type of existing allocator for " << DeviceName(dev.device_type) << "("
<< dev.device_id << ") is different from the request type (" << alloc->type()
<< " vs " << type << ")";
LOG(WARNING) << "The type of existing allocator for " << dev
<< " is different from the request type (" << alloc->type() << " vs " << type
<< ")";
}
return alloc;
}
Expand All @@ -151,8 +149,7 @@ Allocator* MemoryManager::GetAllocator(Device dev) {
std::lock_guard<std::mutex> lock(m->mu_);
auto it = m->allocators_.find(dev);
if (it == m->allocators_.end()) {
LOG(FATAL) << "Allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
<< ") has not been created yet.";
LOG(FATAL) << "Allocator for " << dev << " has not been created yet.";
}
return it->second.get();
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class BuiltinLower : public StmtExprMutator {
<< "but was instead the expression " << device_type_ << " with type "
<< device_type_.value()->GetTypeKey();

String device_name = runtime::DeviceName(as_int->value);
String device_name = runtime::DLDeviceType2Str(as_int->value);
return StringImm("device_api." + device_name + "." + method_name);
}

Expand Down Expand Up @@ -595,7 +595,7 @@ class BuiltinLower : public StmtExprMutator {
let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;

std::string fdevapi_prefix = "device_api.";
fdevapi_prefix += runtime::DeviceName(device_type_.as<IntImmNode>()->value);
fdevapi_prefix += runtime::DLDeviceType2Str(device_type_.as<IntImmNode>()->value);

Array<PrimExpr> args = {
GetDeviceMethodName("alloc_nd"),
Expand Down