Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def build(

annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host)

rt_mod_host = _driver_ffi.preprocess_module(annotated_mods, target_host)
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)

annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host)

Expand Down
81 changes: 16 additions & 65 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,69 +442,8 @@ std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target&
return {host_mod, device_mod};
}

runtime::Module PreProcessModuleForBuild(const Map<Target, IRModule>& inputs_arg,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would almost prefer we use this name or maybe come up with an even better one from build especially in internal APIs since it is not very clear what build means. We have to preserve for historical API compatibility for now but seems better to not keep in C++ if we can.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, since build is used pretty frequently elsewhere, and could be confusing to users. Since most places use "build" to refer to generating something that is executable, I thought it would be appropriate here. But I think I agree, since in most places "build" has an entire sequence of steps, and here there's only the one.

I'm not a big fan of the name PreProcessModuleForBuild, because I read "PreProcess" as implying that this is a pretty early step in the build pipeline, rather than being nearly the last step.

The best name coming to mind is CodeGenerationForLowLevelTIR or GenerateCodeForLowLevelTIR, but waiting for the caffeine to hit may result in better names.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at it with fresh post-weekend eyes, it looks like this function is already part of the external API, as one of the overloads for tvm::build.

As a proposal for naming conventions, I've made an internal tvm::TIRToRuntime, with the name chosen to mimic the existing target-specific codegen hook. All internal usage calls tvm::TIRToRuntime directly, while the tvm::build overload is maintained for the external API.

const Target& host_target) {
std::vector<runtime::Module> device_modules;
Map<Target, IRModule> inputs = inputs_arg;
Target target_host = host_target;

CheckAndUpdateHostConsistency(&inputs, &target_host);

if (!target_host.defined()) {
for (const auto& it : inputs) {
if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) {
target_host = it.first;
break;
}
}
}

if (!target_host.defined()) {
target_host = DefaultTargetHost(target_host);
}

// Update target host for all targets
CheckAndUpdateHostConsistency(&inputs, &target_host);

// Take the attrs from the first module so the eventual modules have them.
// Ideally this would just be one unified module all the way through;
IRModule first_module = (*inputs.begin()).second;
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);
ICHECK(mhost_all.defined()) << "The host module must be defined";

for (const auto& it : inputs) {
if (it.second.defined()) {
auto pair = SplitMixedModule(it.second, it.first, target_host);
auto& host_mod = pair.first;
auto& device_mod = pair.second;

ICHECK(host_mod.defined()) << "The split host module must be defined";

ICHECK(mhost_all.defined()) << "The host module must be defined";

mhost_all->Update(host_mod);

if (device_mod->functions.size() != 0) {
device_modules.push_back(codegen::Build(device_mod, it.first));
}
}
}

runtime::Module complete_mod = codegen::Build(mhost_all, target_host);
for (const auto& it : device_modules) {
if (it.operator->()) {
complete_mod.Import(it);
}
}
return complete_mod;
}

TVM_REGISTER_GLOBAL("driver.preprocess_module")
.set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target host_target) {
return PreProcessModuleForBuild(inputs_arg, host_target);
});

runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& target_host_arg) {
runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
const Target& target_host_arg) {
auto pass_ctx = transform::PassContext::Current();

std::vector<runtime::Module> device_modules;
Expand Down Expand Up @@ -577,6 +516,18 @@ runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& tar
return mhost;
}

TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
.set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target host_target) {
return TIRToRuntime(inputs_arg, host_target);
});

// Build for heterogeneous execution when targets are specified as
// objects. This wrapper around the internal API is maintained for
// backwards compatibility.
runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host) {
return TIRToRuntime(input, target_host);
}

// Build for heterogeneous execution when target is a string.
runtime::Module build(const Map<String, IRModule>& inputs_arg, const Target& target_host_arg) {
Map<Target, IRModule> updated_inputs;
Expand All @@ -590,7 +541,7 @@ runtime::Module build(const Map<String, IRModule>& inputs_arg, const Target& tar
}
updated_inputs.Set(target, it.second);
}
return build(updated_inputs, target_host);
return TIRToRuntime(updated_inputs, target_host);
}

// Build for homogeneous execution.
Expand All @@ -600,7 +551,7 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
CheckAndUpdateHostConsistency(&target, &target_host);
// More maps of target and target host
Map<Target, IRModule> inputs = {{target, funcs}};
return build(inputs, target_host);
return TIRToRuntime(inputs, target_host);
}

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
Expand Down
48 changes: 48 additions & 0 deletions src/driver/internal_driver_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/driver/driver_api.h
* \brief Internal compiler driver APIs to drive the compilation.
*
* This module provides functionality that may be called internally
* within TVM, but is not part of the public-facing API.
*/
#ifndef TVM_DRIVER_INTERNAL_DRIVER_API_H_
#define TVM_DRIVER_INTERNAL_DRIVER_API_H_

#include <tvm/ir/module.h>
#include <tvm/target/target.h>

namespace tvm {

/*!
* \brief Build a device and host module for a specific target from a map
* contains target to IRModule. This function is used
* for heterogeneous build.
* \param input The map contains target to an IRModule.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \return The built module that contains code for different processors.
*/
runtime::Module TIRToRuntime(const Map<Target, IRModule>& input, const Target& target_host);

} // namespace tvm

#endif // TVM_DRIVER_INTERNAL_DRIVER_API_H_
3 changes: 2 additions & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include <memory>

#include "../../driver/internal_driver_api.h"
#include "../../target/func_registry_generator.h"
#include "../../target/metadata_module.h"
#include "../../target/source/codegen_source_base.h"
Expand Down Expand Up @@ -452,7 +453,7 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
}
} else {
ret_.mod = tvm::build(lowered_funcs, host_target);
ret_.mod = tvm::TIRToRuntime(lowered_funcs, host_target);
}

auto ext_mods = executor_codegen_->GetExternalModules();
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include <tuple>
#include <vector>

#include "../../../driver/internal_driver_api.h"
#include "../../../target/metadata_module.h"
#include "../../../target/source/codegen_source_base.h"
#include "../../op/annotation/annotation.h"
Expand Down Expand Up @@ -1158,7 +1159,7 @@ void VMCompiler::Codegen() {
LOG(INFO) << "All lowered functions have been build by BYOC -- generating an empty TVM module";
lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
} else {
lib = tvm::build(per_tvm_target_modules, config_->host_target);
lib = tvm::TIRToRuntime(per_tvm_target_modules, config_->host_target);
}

lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target,
Expand Down