diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 3f2f277d2926..b7d5523ca83e 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -69,7 +69,7 @@ # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel -if not _RUNTIME_ONLY and support.libinfo().get("USE_MICRO", "OFF") == "ON": +if not _RUNTIME_ONLY and support.check_micro_support(): from . import micro # NOTE: This file should be python2 compatible so we can diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 84375b761664..66d5978234b7 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -36,6 +36,7 @@ import tvm._ffi import tvm.ir.transform +import tvm.support from tvm import nd from tvm import rpc as _rpc from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope @@ -537,11 +538,10 @@ def __call__(self, measure_input, tmp_dir, **kwargs): # TODO(tvm-team) consider linline _build_func_common func, arg_info = _build_func_common(measure_input, self.runtime, **kwargs) if self.build_func.output_format == ".model-library-format": + tvm.support.check_micro_support(raise_error=True) # Late import to preserve autoTVM with USE_MICRO OFF - try: - from tvm import micro # pylint: disable=import-outside-toplevel - except ImportError: - raise ImportError("Requires USE_MICRO") + from tvm import micro # pylint: disable=import-outside-toplevel + micro.export_model_library_format(func, filename) else: func.export_library(filename, self.build_func) diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py index 70747cbb2d74..a79857b90dd1 100644 --- a/python/tvm/driver/tvmc/__init__.py +++ b/python/tvm/driver/tvmc/__init__.py @@ -19,7 +19,11 @@ TVMC - TVM driver command-line interface """ -from . import micro +from tvm.support import check_micro_support + +# pylint: disable=wrong-import-position +if check_micro_support(): + from . import micro from . import runner from . import autotuner from . import compiler diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 48bb052124ee..6d4df55eedc8 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -56,11 +56,7 @@ from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule from tvm.runtime.module import BenchmarkResult -try: - from tvm.micro import export_model_library_format -except ImportError: - export_model_library_format = None - +from . import check_micro_support from .common import TVMCException @@ -286,10 +282,11 @@ def export_package( executor_factory, package_path, cross, cross_options, output_format ) elif output_format == "mlf": - if export_model_library_format: - package_path = export_model_library_format(executor_factory, package_path) - else: - raise Exception("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake") + check_micro_support(raise_error=True) + # pylint: disable=import-outside-toplevel + from tvm.micro import export_model_library_format + + package_path = export_model_library_format(executor_factory, package_path) return package_path diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index eb571143e551..00d0d238ac7d 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -34,10 +34,7 @@ from tvm.contrib import graph_executor as runtime from tvm.contrib.debugger import debug_executor from tvm.relay.param_dict import load_param_dict -import tvm.micro.project as project -from tvm.micro.project import TemplateProjectError -from tvm.micro.project_api.client import ProjectAPIServerNotFoundError -from . import common +from . import common, check_micro_support from .common import ( TVMCException, TVMCSuppressedArgumentParser, @@ -135,6 +132,13 @@ def add_run_parser(subparsers, main_parser): # No need to augment the parser for micro targets. return + check_micro_support(raise_error=True) + + # pylint: disable=import-outside-toplevel + import tvm.micro.project as project + from tvm.micro.project import TemplateProjectError + from tvm.micro.project_api.client import ProjectAPIServerNotFoundError + project_dir = known_args.PATH try: @@ -491,6 +495,9 @@ def run_module( # This is guaranteed to work since project_dir was already checked when # building the dynamic parser to accommodate the project options, so no # checks are in place when calling GeneratedProject. + check_micro_support(raise_error=True) + import tvm.micro.project as project # pylint: disable=import-outside-toplevel + project_ = project.GeneratedProject.from_directory(project_dir, options) else: if tvmc_package.type == "mlf": @@ -512,6 +519,7 @@ def run_module( elif device == "micro": # Remote RPC (running on a micro target) logger.debug("Running on remote RPC (micro target).") + check_micro_support(raise_error=True) try: session = tvm.micro.Session(project_.transport()) stack.enter_context(session) diff --git a/python/tvm/support.py b/python/tvm/support.py index 1adbee09c52c..342b9c18cc31 100644 --- a/python/tvm/support.py +++ b/python/tvm/support.py @@ -39,6 +39,17 @@ def libinfo(): return dict(lib_info.items()) +USE_MICRO = libinfo().get("USE_MICRO", "OFF") == "ON" + + +def check_micro_support(raise_error=False): + if USE_MICRO: + return True + if raise_error: + raise Exception("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake") + return False + + class FrontendTestModule(Module): """A tvm.runtime.Module whose member functions are PackedFunc."""