Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 19 additions & 43 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,42 +372,8 @@ class LowerTensorExpr : public ExprMutator {
"in the memory planner.";

auto& device_context = this->device_context_map_[expr];
auto call_dev_type = device_context.device_type;

target = GetTargetFromInteger(device_context.device_type, targets_);
// Non-External Relay Function
if (targets_.size() == 1) {
// The homogeneous execution case, we should only have one target
// so we just grab it.
const auto& it = targets_.begin();
target = (*it).second;
} else {
// The heterogeneous execution case we have multiple targets
// in this case.
//
// We need to identify the target and translate.
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
call_dev_type = kDLCPU;
} else {
call_dev_name = ::tvm::runtime::DeviceName(call_dev_type);
}

if (targets_.count(call_dev_type) == 0) {
std::stringstream msg;
msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n";
msg << call_dev_name << " mapped to device type (" << call_dev_type
<< ") which was not found in the target map.\n";
msg << "Availible targets: \n";
for (auto target : targets_) {
msg << " " << target.first << "-> " << target.second << "\n";
}
LOG(FATAL) << msg.str();
}

target = targets_[call_dev_type];
}

CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compiler_->Lower(key, module_name_);

Expand Down Expand Up @@ -465,19 +431,29 @@ class LowerTensorExpr : public ExprMutator {
*/
Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
if (targets.size() == 1) {
// homogeneous execution.
// The homogeneous execution case, return the only target.
const auto& it = targets.begin();
return (*it).second;
} else {
// heterogeneous execution.
std::string call_dev_name;
if (dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(dev_type);
// The heterogeneous execution case, return the target associated with the
// given device type.
// If "dev_type" equals to 0, the device name only can be got from
// "targets", and it may not be "llvm", so here just set it to "unknown".
std::string dev_name = "unknown";
if (dev_type != 0) {
dev_name = runtime::DeviceName(dev_type);
}

if (targets.count(dev_type) == 0) {
LOG(FATAL) << "No target is provided for device " << call_dev_name;
std::stringstream msg;
msg << "No target is specified for provided device name: `" << dev_name << "`\n\n"
<< dev_name << " mapped to device type (" << dev_type
<< ") which was not found in the target map.\n"
<< "Availible targets: \n";
for (auto target : targets) {
msg << " " << target.first << "-> " << target.second << "\n";
}
LOG(FATAL) << msg.str();
}
return targets[dev_type];
}
Expand Down