diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 34823ebb1781..faa246e34f0d 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -275,7 +275,7 @@ def build( annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host) - rt_mod_host = _driver_ffi.preprocess_module(annotated_mods, target_host) + rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host) annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 54126aaa5119..7d2e8af81cbd 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -442,69 +442,8 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& return {host_mod, device_mod}; } -runtime::Module PreProcessModuleForBuild(const Map& inputs_arg, - const Target& host_target) { - std::vector device_modules; - Map inputs = inputs_arg; - Target target_host = host_target; - - CheckAndUpdateHostConsistency(&inputs, &target_host); - - if (!target_host.defined()) { - for (const auto& it : inputs) { - if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) { - target_host = it.first; - break; - } - } - } - - if (!target_host.defined()) { - target_host = DefaultTargetHost(target_host); - } - - // Update target host for all targets - CheckAndUpdateHostConsistency(&inputs, &target_host); - - // Take the attrs from the first module so the eventual modules have them. - // Ideally this would just be one unified module all the way through; - IRModule first_module = (*inputs.begin()).second; - IRModule mhost_all = IRModule(Map(), {}, {}, {}, first_module->attrs); - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - for (const auto& it : inputs) { - if (it.second.defined()) { - auto pair = SplitMixedModule(it.second, it.first, target_host); - auto& host_mod = pair.first; - auto& device_mod = pair.second; - - ICHECK(host_mod.defined()) << "The split host module must be defined"; - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - mhost_all->Update(host_mod); - - if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Build(device_mod, it.first)); - } - } - } - - runtime::Module complete_mod = codegen::Build(mhost_all, target_host); - for (const auto& it : device_modules) { - if (it.operator->()) { - complete_mod.Import(it); - } - } - return complete_mod; -} - -TVM_REGISTER_GLOBAL("driver.preprocess_module") - .set_body_typed([](const Map& inputs_arg, Target host_target) { - return PreProcessModuleForBuild(inputs_arg, host_target); - }); - -runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { +runtime::Module TIRToRuntime(const Map& inputs_arg, + const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); std::vector device_modules; @@ -577,6 +516,18 @@ runtime::Module build(const Map& inputs_arg, const Target& tar return mhost; } +TVM_REGISTER_GLOBAL("driver.tir_to_runtime") + .set_body_typed([](const Map& inputs_arg, Target host_target) { + return TIRToRuntime(inputs_arg, host_target); + }); + +// Build for heterogeneous execution when targets are specified as +// objects. This wrapper around the internal API is maintained for +// backwards compatibility. +runtime::Module build(const Map& input, const Target& target_host) { + return TIRToRuntime(input, target_host); +} + // Build for heterogeneous execution when target is a string. runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { Map updated_inputs; @@ -590,7 +541,7 @@ runtime::Module build(const Map& inputs_arg, const Target& tar } updated_inputs.Set(target, it.second); } - return build(updated_inputs, target_host); + return TIRToRuntime(updated_inputs, target_host); } // Build for homogeneous execution. @@ -600,7 +551,7 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, CheckAndUpdateHostConsistency(&target, &target_host); // More maps of target and target host Map inputs = {{target, funcs}}; - return build(inputs, target_host); + return TIRToRuntime(inputs, target_host); } transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { diff --git a/src/driver/internal_driver_api.h b/src/driver/internal_driver_api.h new file mode 100644 index 000000000000..3b7cc7c7f7fa --- /dev/null +++ b/src/driver/internal_driver_api.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/driver/driver_api.h + * \brief Internal compiler driver APIs to drive the compilation. + * + * This module provides functionality that may be called internally + * within TVM, but is not part of the public-facing API. + */ +#ifndef TVM_DRIVER_INTERNAL_DRIVER_API_H_ +#define TVM_DRIVER_INTERNAL_DRIVER_API_H_ + +#include +#include + +namespace tvm { + +/*! + * \brief Build a device and host module for a specific target from a map + * contains target to IRModule. This function is used + * for heterogeneous build. + * \param input The map contains target to an IRModule. + * \param target_host The target for building host code. To use the default, + * pass Target(). + * \return The built module that contains code for different processors. + */ +runtime::Module TIRToRuntime(const Map& input, const Target& target_host); + +} // namespace tvm + +#endif // TVM_DRIVER_INTERNAL_DRIVER_API_H_ diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 89ee61c83f7c..831a0a459421 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -35,6 +35,7 @@ #include +#include "../../driver/internal_driver_api.h" #include "../../target/func_registry_generator.h" #include "../../target/metadata_module.h" #include "../../target/source/codegen_source_base.h" @@ -452,7 +453,7 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array{}); } } else { - ret_.mod = tvm::build(lowered_funcs, host_target); + ret_.mod = tvm::TIRToRuntime(lowered_funcs, host_target); } auto ext_mods = executor_codegen_->GetExternalModules(); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e94919de7f20..29a493dc0ff9 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -46,6 +46,7 @@ #include #include +#include "../../../driver/internal_driver_api.h" #include "../../../target/metadata_module.h" #include "../../../target/source/codegen_source_base.h" #include "../../op/annotation/annotation.h" @@ -1158,7 +1159,7 @@ void VMCompiler::Codegen() { LOG(INFO) << "All lowered functions have been build by BYOC -- generating an empty TVM module"; lib = codegen::CSourceModuleCreate(";", "", Array{}); } else { - lib = tvm::build(per_tvm_target_modules, config_->host_target); + lib = tvm::TIRToRuntime(per_tvm_target_modules, config_->host_target); } lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target,