diff --git a/docs/arch/relay_op_strategy.rst b/docs/arch/relay_op_strategy.rst index c40251d22433..dbac7c821827 100644 --- a/docs/arch/relay_op_strategy.rst +++ b/docs/arch/relay_op_strategy.rst @@ -269,14 +269,14 @@ will then be chosen. Implementations with same priority level in this case leads to an undefined behavior, and any of them might be selected. The selection policy for ops with symbolic input shapes is still work in -progess. Currently, if any input tensor has a symbolic shape, only the +progress. Currently, if any input tensor has a symbolic shape, only the implementation with highest priority level will be used for this operator. This -will be updated after the implemention finishes. +will be updated after the implementation finishes. For debug purpose, you can add the following lines before you compile the Relay model to learn which implementation is used for each operator. .. code:: python - logging.getLogger("compile_engine").setLevel(logging.INFO) - logging.getLogger("compile_engine").addHandler(logging.StreamHandler(sys.stdout)) + logging.getLogger("te_compiler").setLevel(logging.INFO) + logging.getLogger("te_compiler").addHandler(logging.StreamHandler(sys.stdout)) diff --git a/docs/reference/api/python/relay/backend.rst b/docs/reference/api/python/relay/backend.rst index ffe8a9a8ce79..e717ee10ffab 100644 --- a/docs/reference/api/python/relay/backend.rst +++ b/docs/reference/api/python/relay/backend.rst @@ -23,7 +23,7 @@ tvm.relay.backend .. automodule:: tvm.relay.backend.interpreter :members: -.. automodule:: tvm.relay.backend.compile_engine +.. automodule:: tvm.relay.backend.te_compiler :members: .. automodule:: tvm.relay.backend.graph_executor_codegen diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0eacd1a1f667..6f35e021daf8 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -58,7 +58,6 @@ def call_all_topi_funcs(mod, params, target, opt_level=3): opt_level=opt_level, config={ "relay.backend.use_auto_scheduler": True, - "relay.backend.disable_compile_engine_cache": True, }, disabled_pass={"AutoSchedulerLayoutRewrite"}, ): @@ -165,7 +164,8 @@ class TracingMode: """Two modes for tracing""" EXTRACT_TASK = 0 # trace all topi calls to extract tasks - EXTRACT_COMPLEX_TASK_ONLY = 1 # same as EXTRACT_TASK but ignore the task without complex ops + # same as EXTRACT_TASK but ignore the task without complex ops + EXTRACT_COMPLEX_TASK_ONLY = 1 PREPARE_LAYOUT_REWRITE = 2 # trace topi calls to prepare layout rewrite diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 723e7fa77006..7299875bf28d 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -142,7 +142,7 @@ def _traverse_expr(node): params.append(free_var) call = relay.Call(node.op, params, node.attrs) mod = tvm.IRModule.from_expr(relay.Function(params, call)) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() tracing_target = _replace_device_with_tracing(tvm_target) build_thread = threading.Thread( target=relay.build, args=(mod, tracing_target, None, None) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 714dd540d3ab..4716116a1b83 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -127,12 +127,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No assert isinstance( mod, tvm.IRModule ), "only support relay Module or Function to be tuned" - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() # wrap build call in thread to avoid multiprocessing problems build_thread = threading.Thread(target=_lower, args=(mod, target, param)) build_thread.start() build_thread.join() - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() # Clear the warning message cache in FallbackContext if isinstance(DispatchContext.current, FallbackContext): DispatchContext.current.memory = {} diff --git a/python/tvm/relay/backend/__init__.py b/python/tvm/relay/backend/__init__.py index 4fc2b63748db..d76459236515 100644 --- a/python/tvm/relay/backend/__init__.py +++ b/python/tvm/relay/backend/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """Backend codegen modules for relay.""" -from . import compile_engine +from . import te_compiler diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/te_compiler.py similarity index 79% rename from python/tvm/relay/backend/compile_engine.py rename to python/tvm/relay/backend/te_compiler.py index e9129db7b200..db7504915887 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=len-as-condition,no-else-return,invalid-name -"""Backend code generation engine.""" +"""TE compiler engine (replacing legacy compile_engine).""" from __future__ import absolute_import import logging -import numpy as np import tvm from tvm import te, autotvm from tvm.ir.transform import PassContext @@ -31,7 +30,7 @@ from .. import ty as _ty from . import _backend -logger = logging.getLogger("compile_engine") +logger = logging.getLogger("te_compiler") autotvm_logger = logging.getLogger("autotvm") _first_warning = True @@ -47,7 +46,7 @@ def __init__(self, outputs, implement): @tvm._ffi.register_object("relay.CCacheKey") class CCacheKey(Object): - """Key in the CompileEngine. + """Key in the TE Compiler. Parameters ---------- @@ -64,7 +63,7 @@ def __init__(self, source_func, target): @tvm._ffi.register_object("relay.CCacheValue") class CCacheValue(Object): - """Value in the CompileEngine, including usage statistics.""" + """Value in the TE Compiler, including usage statistics.""" def _get_cache_key(source_func, target): @@ -79,24 +78,6 @@ def _get_cache_key(source_func, target): return source_func -def get_shape(shape): - """Convert the shape to correct dtype and vars.""" - ret = [] - for dim in shape: - if isinstance(dim, tvm.tir.IntImm): - if libinfo()["INDEX_DEFAULT_I64"] == "ON": - ret.append(dim) - else: - val = int(dim) - assert val <= np.iinfo(np.int32).max - ret.append(tvm.tir.IntImm("int32", val)) - elif isinstance(dim, tvm.tir.Any): - ret.append(te.var("any_dim", "int32")) - else: - ret.append(dim) - return ret - - def get_valid_implementations(op, attrs, inputs, out_type, target): """Get all valid implementations from the op strategy. @@ -275,6 +256,24 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) return best_plevel_impl, outputs[best_plevel_impl] +def get_shape(shape): + """Convert the shape to correct dtype and vars.""" + ret = [] + for dim in shape: + if isinstance(dim, tvm.tir.IntImm): + if libinfo()["INDEX_DEFAULT_I64"] == "ON": + ret.append(dim) + else: + val = int(dim) + assert val <= np.iinfo(np.int32).max + ret.append(tvm.tir.IntImm("int32", val)) + elif isinstance(dim, tvm.tir.Any): + ret.append(te.var("any_dim", "int32")) + else: + ret.append(dim) + return ret + + @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" @@ -322,12 +321,12 @@ def lower_call(call, inputs, target): return LoweredOutput(outputs, best_impl) -@tvm._ffi.register_object("relay.CompileEngine") -class CompileEngine(Object): - """CompileEngine to get lowered code.""" +@tvm._ffi.register_object("relay.TECompiler") +class TECompiler(Object): + """TECompiler to get lowered code.""" def __init__(self): - raise RuntimeError("Cannot construct a CompileEngine") + raise RuntimeError("Cannot construct a TECompiler") def lower(self, source_func, target=None, mod_name="default"): """Lower a source_func to a CachedFunc. @@ -349,7 +348,7 @@ def lower(self, source_func, target=None, mod_name="default"): try: mod_name = mangle_module_name(mod_name) key = _get_cache_key(source_func, target) - return _backend._CompileEngineLower(self, key, mod_name) + return _backend._TECompilerLower(self, key, mod_name) except Exception: import traceback @@ -360,10 +359,6 @@ def lower(self, source_func, target=None, mod_name="default"): msg += "--------------------------\n" raise RuntimeError(msg) - def lower_shape_func(self, source_func, target=None): - key = _get_cache_key(source_func, target) - return _backend._CompileEngineLowerShapeFunc(self, key) - def jit(self, source_func, target=None): """JIT a source_func to a tvm.runtime.PackedFunc. @@ -381,87 +376,30 @@ def jit(self, source_func, target=None): The result of jited function. """ key = _get_cache_key(source_func, target) - return _backend._CompileEngineJIT(self, key) + return _backend._TECompilerJIT(self, key) def clear(self): """clear the existing cached functions""" - _backend._CompileEngineClear(self) + _backend._TECompilerClear(self) def items(self): """List items in the cache. - Returns ------- item_list : List[Tuple[CCacheKey, CCacheValue]] The list of items. """ - res = _backend._CompileEngineListItems(self) - assert len(res) % 2 == 0 - return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] - - def shape_func_items(self): - """List items in the shape_func_cache. - - Returns - ------- - item_list : List[Tuple[CCacheKey, CCacheValue]] - The list of shape_func_items. - """ - res = _backend._CompileEngineListShapeFuncItems(self) + res = _backend._TECompilerListItems(self) assert len(res) % 2 == 0 return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] - def get_current_ccache_key(self): - return _backend._CompileEngineGetCurrentCCacheKey(self) - - def dump(self): - """Return a string representation of engine dump. - - Returns - ------- - dump : str - The dumped string representation - """ - items = self.items() - res = "====================================\n" - res += "CompilerEngine dump, %d items cached\n" % len(items) - for k, v in items: - res += "------------------------------------\n" - res += "target={}\n".format(k.target) - res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) - res += "----relay function----\n" - res += k.source_func.astext() + "\n" - res += "----tir function----- \n" - res += "inputs={}\n".format(v.cached_func.inputs) - res += "outputs={}\n".format(v.cached_func.outputs) - res += "function: \n" - res += v.cached_func.funcs.astext() + "\n" - res += "===================================\n" - shape_func_items = self.shape_func_items() - res += "%d shape_func_items cached\n" % len(shape_func_items) - for k, v in shape_func_items: - res += "------------------------------------\n" - res += "target={}\n".format(k.target) - res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) - res += "----relay function----\n" - res += k.source_func.astext() + "\n" - res += "----tir function----- \n" - res += "inputs={}\n".format(v.cached_func.inputs) - res += "outputs={}\n".format(v.cached_func.outputs) - res += "function: \n" - res += v.cached_func.funcs.astext() + "\n" - res += "===================================\n" - return res - def get(): - """Get the global compile engine. + """Get the global TE Compiler. Returns ------- - engine : tvm.relay.backend.CompileEngine - The compile engine. + engine : tvm.relay.backend.TECompiler + The TE Compiler. """ - return _backend._CompileEngineGlobal() + return _backend._TECompilerGlobal() diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index b9d6806306f4..50f473aea1f2 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -24,7 +24,7 @@ import tvm from tvm import relay from tvm.relay.adt import Pattern -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler from tvm.relay.expr import Expr, GlobalVar, Var from tvm.relay.function import Function from tvm.relay.expr_functor import ExprFunctor @@ -61,7 +61,7 @@ def __init__(self, mod, target) -> None: super().__init__() self.mod = mod self.tgt = target - self.engine = compile_engine.get() + self.tec = te_compiler.get() self.fun_no = 0 self.var_no = 0 self.var_map = {} @@ -153,7 +153,10 @@ def parse_name(self, name: str): def parse_numpy_array(self, arr): """Given a Numpy array, produces an appropriate Python array or numerical literal representing its contents.""" - parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i) + + def parse_single(i): + return NameConstant(i) if isinstance(i, bool) else Num(i) + if arr.ndim == 0: return parse_single(arr.item()) if arr.ndim == 1: @@ -240,11 +243,11 @@ def create_op_call(self, op: Function, relay_args, py_args): the generated Python code.""" # compile the function and register globally - cc_key = compile_engine.CCacheKey(op, self.tgt) + cc_key = te_compiler.CCacheKey(op, self.tgt) func_hash = tvm.ir.structural_hash(op) op_name = "_lowered_op_{}".format(func_hash) if not tvm.get_global_func(op_name, allow_missing=True): - jitted = self.engine.jit(cc_key, self.tgt) + jitted = self.tec.jit(cc_key, self.tgt) tvm.register_func(op_name, jitted) def convert_input(py_input, arg_type): diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index c7c572c81110..cbe8644c885f 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -90,7 +90,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/bifrost/conv2d.py b/python/tvm/topi/bifrost/conv2d.py index 3b6cca6aaea4..633f36c0e7ff 100644 --- a/python/tvm/topi/bifrost/conv2d.py +++ b/python/tvm/topi/bifrost/conv2d.py @@ -477,7 +477,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 4863a06b728d..3d05058ff52c 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -46,7 +46,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/conv3d_alter_op.py b/python/tvm/topi/cuda/conv3d_alter_op.py index faf73e77255a..c7ec7cb21fcf 100644 --- a/python/tvm/topi/cuda/conv3d_alter_op.py +++ b/python/tvm/topi/cuda/conv3d_alter_op.py @@ -35,7 +35,7 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/intel_graphics/conv2d_alter_op.py b/python/tvm/topi/intel_graphics/conv2d_alter_op.py index 0b59a849c2c9..199d984af1e4 100644 --- a/python/tvm/topi/intel_graphics/conv2d_alter_op.py +++ b/python/tvm/topi/intel_graphics/conv2d_alter_op.py @@ -35,7 +35,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): cfg = dispatch_ctx.query(target, None) workload = cfg.workload else: - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/mali/conv2d.py b/python/tvm/topi/mali/conv2d.py index f3ef55b9a30c..051914113a5b 100644 --- a/python/tvm/topi/mali/conv2d.py +++ b/python/tvm/topi/mali/conv2d.py @@ -531,7 +531,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py index 8e47dff37ce6..3f2df655a615 100644 --- a/python/tvm/topi/x86/conv2d_alter_op.py +++ b/python/tvm/topi/x86/conv2d_alter_op.py @@ -57,7 +57,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): cfg = dispatch_ctx.query(target, None) workload = cfg.workload else: - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 8db84497f82d..1d64261a50d7 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -35,7 +35,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): M, K = get_const_tuple(data_tensor.shape) N, _ = get_const_tuple(weight_tensor.shape) - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.dense"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ef82ed617508..7005e94c2411 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -33,7 +33,7 @@ #include "../../target/func_registry_generator.h" #include "../../target/source/codegen_source_base.h" -#include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { @@ -295,8 +295,6 @@ class RelayBuildModule : public runtime::ModuleNode { executor_ = executor; CheckAndUpdateHostConsistency(&targets_, &target_host_); BuildRelay(mod, params_, mod_name); - // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. - CompileEngine::Global()->Clear(); } protected: diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc deleted file mode 100644 index 0e7af2278375..000000000000 --- a/src/relay/backend/compile_engine.cc +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.cc - * \brief Internal compilation engine. - */ -#include "compile_engine.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../runtime/meta_data.h" -#include "../transforms/pass_utils.h" -#include "te_compiler_cache.h" -#include "utils.h" - -namespace tvm { -namespace relay { - -TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); - -class CompileEngineImpl : public CompileEngineNode { - public: - // Lower the function. - CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { - return LowerInternal(key, mangle_fn)->cached_func; - } - - CachedFunc Lower(const CCacheKey& key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; - - return Lower(key, mangle_fn); - } - - // For now, build one module per function. - PackedFunc JIT(const CCacheKey& key) final { - auto mangle_fn = [](String name) { return name; }; - CCacheValue value = LowerInternal(key, mangle_fn); - if (value->packed_func != nullptr) return value->packed_func; - auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); - value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); - return value->packed_func; - } - - CachedFunc LowerShapeFunc(const CCacheKey& key) final { - return LowerShapeFuncInternal(key)->cached_func; - } - - Array LowerExternalFunctions() { - Array ret; - std::unordered_map cached_symbol; - std::vector cached_ext_funcs; - for (const auto& it : cache_) { - auto src_func = it.first->source_func; - ICHECK(src_func.defined()); - - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); - ICHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen.value(); - cached_ext_funcs.push_back(it.first); - - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" - << AsText(src_func, false) << "\n" - << "Functions with external codegen must have the " - << tvm::attr::kGlobalSymbol << " attr set."; - - std::string sn = symbol_name.value(); - if (!cached_symbol.count(sn)) { - cached_symbol[sn] = code_gen_name; - } else { - ICHECK_NE(cached_symbol[sn], code_gen_name) - << "Found duplicated symbol: " << sn << " for: " << code_gen_name; - } - - std::string ext_name = "relay.ext." + code_gen_name; - auto pf = tvm::runtime::Registry::Get(ext_name); - ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; - // No need to keep compiler attribute at this point, functions have been - // extracted for specific codegen. - src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); - runtime::Module ext_mod = (*pf)(src_func); - - // todo(@zhiics, @jroesch): Should this be a user visible error? - ICHECK(ext_mod.defined()) << "No external library was generated for " << ext_name - << "even though it was requested" - "by the annotated function " - << PrettyPrint(src_func); - - ret.push_back(ext_mod); - } - } - - // No need to cache external functions as we collected them all to create - // external runtime modules. - for (const auto& it : cached_ext_funcs) { - cache_.erase(it); - } - return ret; - } - - void Clear() final { cache_.clear(); } - - // List all items in the cache. - Array ListItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - // List all items in the shape_func_cache. - Array ListShapeFuncItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : shape_func_cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - /*! - * \brief Get the cache key of the function that is being lowered currently - * \return the cache key - */ - CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } - - private: - // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = cache_.find(key); - if (it != cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - if (!backend::IsCompileEngineCacheDisabled()) { - cache_[key] = value; - } - } - cur_ccache_key_ = key; - - // No need to lower external functions for now. We will invoke the external - // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto ir_module = IRModule(); - const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - auto func_name = std::string(name_node.value()); - auto target = Target("ext_dev"); - auto global_var = GlobalVar(func_name); - global_var->checked_type_ = key->source_func->checked_type(); - ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); - return value; - } - - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(mangle_fn(name), &name_map_); - }); - - // Skip lowering for device copy node. - const Expr body = (key->source_func)->body; - if (const CallNode* call_node = body.as()) { - if (call_node->attrs.as()) { - value->cached_func = cfunc; - return value; - } - } - - // NOTE: array will copy on write. - Array all_args = Array(cfunc->inputs); - for (te::Tensor arg : cfunc->outputs) { - all_args.push_back(arg); - } - // lower the function - std::unordered_map binds; - auto func_name = cfunc->prim_fn_var->name_hint; - cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); - value->cached_func = cfunc; - - return value; - } - - // implement lowered shape func - CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = shape_func_cache_.find(key); - if (it != shape_func_cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - shape_func_cache_[key] = value; - } - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); - - auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(name, &name_map_); - }); - - value->cached_func = cached_func; - return value; - } - - /*! \brief compiler cache lock*/ - std::mutex mutex_; - /*! \brief internal name map to get an unique name */ - std::unordered_map name_map_; - /*! \brief internal compiler cache */ - std::unordered_map cache_; - /*! \brief internal compiler cache for shape funcs */ - std::unordered_map shape_func_cache_; - /*! \brief the cache key of the function that is being lowered currently*/ - CCacheKey cur_ccache_key_; -}; - -/*! \brief The global compile engine */ -CompileEngine& CompileEngine::Global() { - // intentionally allocate raw pointer to avoid - // free during destructuion. - static CompileEngine* inst = new CompileEngine(make_object()); - return *inst; -} - -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool); - -TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") - .set_body_typed([](tvm::Array outputs, OpImplementation impl) { - return LoweredOutput(outputs, impl); - }); - -TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") - .set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { - return CompileEngine::Global(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) { - self->Clear(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") - .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { - return self->Lower(key, mod_name); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") - .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListItems(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListShapeFuncItems(); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->GetCurrentCCacheKey(); - }); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h deleted file mode 100644 index 4afdc6d30485..000000000000 --- a/src/relay/backend/compile_engine.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.h - * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. - * - * This layer represents the older design of the Relay compilation flow and is being deprecated - * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of - * Relay functions. - * - */ -#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ -#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "te_compiler_cache.h" - -namespace tvm { -namespace relay { - -using namespace tvm::relay::tec; - -/*! - * \brief Backend compilation engine for - * low level code generation. - */ -class CompileEngineNode : public Object { - public: - /*! \brief destructor */ - virtual ~CompileEngineNode() {} - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The mangling function for mangling names. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; - - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The module name to mangle the functions. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, const String mangle_fn) = 0; - /*! - * \brief Just in time compile to get a PackedFunc. - * \param key The key to the cached function. - * \return The result. - */ - virtual PackedFunc JIT(const CCacheKey& key) = 0; - /*! - * \brief Lower the shape function. - * \param key The key to the cached function. - * \return The result. - */ - virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; - /*! - * \brief Lower the external function using external codegen tools. - * \return The runtime moduels for each needed external codegen tool. - */ - virtual tvm::Array LowerExternalFunctions() = 0; - - /*! \brief clear the cache. */ - virtual void Clear() = 0; - - // VisitAttrs - void VisitAttrs(AttrVisitor*) {} - - static constexpr const char* _type_key = "relay.CompileEngine"; - TVM_DECLARE_FINAL_OBJECT_INFO(CompileEngineNode, Object); -}; - -/*! \brief cache entry used in compile engine */ -class CompileEngine : public ObjectRef { - public: - CompileEngine() {} - explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} - CompileEngineNode* operator->() { return static_cast(get_mutable()); } - using ContainerType = CompileEngineNode; - /*! \brief The global compile engine. */ - TVM_DLL static CompileEngine& Global(); -}; - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index ef89fd9c9c6c..a596e09907d5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -37,7 +37,7 @@ #include "../op/annotation/annotation.h" #include "../transforms/pass_utils.h" -#include "./te_compiler.h" +#include "te_compiler.h" namespace tvm { namespace relay { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 445602540dbb..a8c27a126032 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -313,6 +313,45 @@ TECompiler::TECompiler() { data_ = object; } +/*! \brief The global TE compiler */ +TECompiler& TECompiler::Global() { + static TECompiler* inst = new TECompiler(make_object()); + return *inst; +} +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() { + return TECompiler::Global(); +}); + +TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") + .set_body_typed([](Function source_func, Target target) { + return CCacheKey(source_func, target); + }); + +TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") + .set_body_typed([](tvm::Array outputs, OpImplementation impl) { + return LoweredOutput(outputs, impl); + }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerClear").set_body_typed([](TECompiler self) { + self->Clear(); +}); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerLower") + .set_body_typed([](TECompiler self, CCacheKey key, const String mod_name) { + return self->Lower(key, mod_name); + }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerJIT") + .set_body_typed([](TECompiler self, CCacheKey key) { return self->JIT(key); }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECompiler self) { + TECompilerImpl* ptr = dynamic_cast(self.operator->()); + ICHECK(ptr != nullptr); + return ptr->ListItems(); +}); + using AnalysisRemapping = std::unordered_map; std::tuple IsDeviceCopy(const Function& func) { diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 248fd40f98eb..e3b7d46457ad 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -127,6 +127,7 @@ class TECompiler : public ObjectRef { explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} TECompilerNode* operator->() { return static_cast(get_mutable()); } using ContainerType = TECompilerNode; + TVM_DLL static TECompiler& Global(); }; /*! @@ -193,7 +194,7 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \returns The pass which lowers primative functions to TIR + * \returns The pass which lowers primitive functions to TIR */ transform::Pass LowerTEPass(TargetMap targets, const String& module_name, std::function process_fn); diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 47ba96b2c77e..7975ef873173 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -62,7 +62,6 @@ struct LoweredOutputNode : public Object { v->Visit("outputs", &outputs); v->Visit("implementation", &implementation); } - static constexpr const char* _type_key = "relay.LoweredOutput"; TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); }; diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 6d59b858927c..febb550d45c0 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -427,15 +427,6 @@ inline bool IsAutoSchedulerEnabled() { .value(); } -/*! - * \brief Return whether the compile engine cache is disabled in the pass context. - */ -inline bool IsCompileEngineCacheDisabled() { - return transform::PassContext::Current() - ->GetConfig("relay.backend.disable_compile_engine_cache", Bool(false)) - .value(); -} - /*! * \brief Get the sequence of Relay optimization passes based on backend type. * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 7a86af8aeffa..c538dac048b3 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -34,7 +34,7 @@ #include #include -#include "../backend/compile_engine.h" +#include "../backend/te_compiler.h" #include "pattern_utils.h" namespace tvm { @@ -126,7 +126,8 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); + tec::PrimFuncFor(GetRef(func), Target::Current(), + [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 3cd5df613f4a..4e24434642d8 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -41,7 +41,7 @@ namespace runtime { struct TypeInfo { /*! \brief The current index. */ uint32_t index{0}; - /*! \brief Index of the parent in the type hierachy */ + /*! \brief Index of the parent in the type hierarchy */ uint32_t parent_index{0}; // NOTE: the indices in [index, index + num_reserved_slots) are // reserved for the child-class of this type. @@ -58,7 +58,7 @@ struct TypeInfo { }; /*! - * \brief Type context that manages the type hierachy information. + * \brief Type context that manages the type hierarchy information. */ class TypeContext { public: diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index f151a85ec5b1..e582874d1de2 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -184,7 +184,7 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti ), "Got {} Arm Compute Library partitions, expected {}".format( partition_count, acl_partitions ) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target=target, params=params) diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py index 46bd049402a9..5a12b0487408 100644 --- a/tests/python/contrib/test_bnns/infrastructure.py +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -142,7 +142,7 @@ def build_module(mod, target, params=None, enable_bnns=True, tvm_ops=0): with tvm.transform.PassContext(opt_level=3): if enable_bnns: mod = partition_for_bnns(mod) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target=target, target_host=target, params=params) diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index 92e8f11a2312..c5ebde4b9c61 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -149,7 +149,7 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1): npu_partitions : int, optional The number of Ethos-N partitions expected. """ - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() with tvm.transform.PassContext( opt_level=3, config={"relay.ext.ethos-n.options": {"variant": get_ethosn_variant()}} ): @@ -262,7 +262,7 @@ def test_error(mod, params, err_msg): except tvm.error.TVMError as e: caught = e.args[0] finally: - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() assert caught is not None assert err_msg in caught, caught diff --git a/tests/python/contrib/test_vitis_ai/infrastructure.py b/tests/python/contrib/test_vitis_ai/infrastructure.py index e87d4f874630..578ac37da25b 100644 --- a/tests/python/contrib/test_vitis_ai/infrastructure.py +++ b/tests/python/contrib/test_vitis_ai/infrastructure.py @@ -99,7 +99,7 @@ def build_module( ), "Got {} Vitis-AI partitions, expected {}".format( partition_count, vitis_ai_partitions ) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target, params=params) diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 746f595a4422..276cad375357 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -33,8 +33,10 @@ import tvm from tvm import relay +from tvm import te from tvm.contrib import utils, graph_executor -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler +from tvm.relay.backend.te_compiler import TECompiler from tvm.relay.backend.utils import mangle_module_name from tvm.micro import export_model_library_format @@ -721,7 +723,6 @@ def compile_and_run( def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" - compile_engine.get().clear() with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target, params=params) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 22583eda4a40..7669d02cd536 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -41,7 +41,7 @@ def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets() tvm.testing.assert_allclose(op_result.numpy(), ref_result, rtol=1e-5) else: tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() @tvm.testing.uses_gpu @@ -251,7 +251,8 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ verify_sparse_to_dense( [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1] ) # floats - verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + # default value not specified + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) @pytest.mark.parametrize( diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index ca792204c835..c6eb7531f635 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -26,7 +26,7 @@ from tvm import relay, runtime from tvm.contrib import utils from tvm.relay import transform -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.contrib.register import get_pattern_table @@ -47,7 +47,7 @@ def check_result( return # Run the reference result - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(ref_mod, target=target, params=params) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) @@ -61,7 +61,7 @@ def check_result( ref_result = out.numpy() def check_vm_result(): - compile_engine.get().clear() + te_compiler.get().clear() with relay.build_config(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -71,7 +71,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.numpy(), ref_result, rtol=tol, atol=tol) def check_graph_executor_result(): - compile_engine.get().clear() + te_compiler.get().clear() with relay.build_config(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index eaddd33678df..754c9d1c4a74 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1422,7 +1422,8 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ verify_sparse_to_dense( [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1] ) # floats - verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + # default value not specified + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # negative test cases # sparse indices should be ints @@ -1757,7 +1758,7 @@ def verify_func(target, dev, func, data, ref_res): tvm.testing.assert_allclose(op_result.numpy(), ref_result, rtol=1e-5) else: tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() def test_adv_index(target, dev, executor_kind): @@ -1970,7 +1971,8 @@ def calc_numpy_unique(data, is_sorted=False): uniq = uniq[order].astype(data.dtype) inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") counts = counts[order].astype("int32") - index = np.sort(index) # In unsorted case, need to sort the index of first occurence + # In unsorted case, need to sort the index of first occurence + index = np.sort(index) return [ uniq.astype(data.dtype), index.astype("int32"), diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 93cd6f791765..5aba6229c5e2 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -22,6 +22,7 @@ import numpy as np import tvm +from tvm.relay.backend import te_compiler import tvm.relay.testing import tvm.relay.op as reg from tvm import relay @@ -29,7 +30,6 @@ from tvm.relay import transform from tvm.relay.testing import byoc from tvm.contrib import utils -from tvm.relay.backend import compile_engine from tvm.relay.expr_functor import ExprMutator from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.contrib.register import get_pattern_table @@ -143,7 +143,7 @@ def update_lib(lib): return lib def check_vm_result(): - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -157,7 +157,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol) def check_graph_executor_result(): - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) lib = update_lib(lib) @@ -508,7 +508,7 @@ def test_extern_dnnl_mobilenet(): ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( i_data, **params ) - compile_engine.get().clear() + te_compiler.get().clear() check_result(mod, {"data": i_data}, (1, 1000), ref_res.numpy(), tol=1e-5, params=params) @@ -950,7 +950,7 @@ def test_exec(mod, params, ref_mod, ref_params, out_shape): ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( i_data, **ref_params ) - compile_engine.get().clear() + te_compiler.get().clear() mod = get_partitoned_mod(mod, params, dnnl_patterns) diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_relay_te_compiler.py similarity index 93% rename from tests/python/relay/test_backend_compile_engine.py rename to tests/python/relay/test_relay_te_compiler.py index 092cae01f568..f8498ae83648 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_relay_te_compiler.py @@ -21,6 +21,7 @@ from tvm import relay from tvm import autotvm from tvm import topi +from tvm.relay.backend import te_compiler from tvm.relay.testing import run_infer_type from tvm.relay.testing.temp_op_attr import TempOpAttr import tvm.testing @@ -98,7 +99,7 @@ def _get_impls(dshape, wshape): weight = relay.var("wshape", shape=wshape) out = relay.nn.conv2d(data, weight, padding=(1, 1)) out = run_infer_type(out) - return relay.backend.compile_engine.get_valid_implementations( + return relay.backend.te_compiler.get_valid_implementations( relay.op.get("nn.conv2d"), out.attrs, [te.placeholder(dshape), te.placeholder(wshape)], @@ -121,7 +122,7 @@ def _select_impl(dshape, wshape, use_autotvm=False): weight = relay.var("wshape", shape=wshape) out = relay.nn.conv2d(data, weight, padding=(1, 1)) out = run_infer_type(out) - return relay.backend.compile_engine.select_implementation( + return relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), out.attrs, [te.placeholder(dshape), te.placeholder(wshape)], @@ -161,8 +162,8 @@ def _select_impl(dshape, wshape, use_autotvm=False): assert impl.name == "conv2d_1" -def test_compile_engine(): - engine = relay.backend.compile_engine.get() +def test_te_compiler(): + tec = relay.backend.te_compiler.get() def get_func(shape): x = relay.var("x", shape=shape) @@ -173,31 +174,30 @@ def get_func(shape): mod = relay.transform.InferType()(mod) return mod["main"] - z1 = engine.lower(get_func((10,)), "llvm") - z2 = engine.lower(get_func((10,)), "llvm") - z3 = engine.lower(get_func(()), "llvm") + z1 = tec.lower(get_func((10,)), "llvm") + z2 = tec.lower(get_func((10,)), "llvm") + z3 = tec.lower(get_func(()), "llvm") assert z1.same_as(z2) assert not z3.same_as(z1) if tvm.testing.device_enabled("cuda"): - z4 = engine.lower(get_func(()), "cuda") + z4 = tec.lower(get_func(()), "cuda") assert not z3.same_as(z4) # Test JIT target for target in ["llvm"]: dev = tvm.device(target) if tvm.testing.device_enabled(target): - f = engine.jit(get_func((10,)), target) + f = tec.jit(get_func((10,)), target) x = tvm.nd.array(np.ones(10).astype("float32"), device=dev) y = tvm.nd.empty((10,), device=dev) f(x, y) tvm.testing.assert_allclose(y.numpy(), x.numpy() * 3) - engine.dump() -# Note: Once compile engine is removed, we should keep this test so that +# Note: Once the te compiler is removed, we should keep this test so that # we make sure that opt_level=0 passes are being called correctly. def test_compile_placeholder_bypass(): - engine = relay.backend.compile_engine.get() + te_compiler = relay.backend.te_compiler.get() x = relay.var("x", shape=(2, 3)) y = relay.var("y", shape=(2, 3)) z = relay.var("z", shape=(2, 3)) @@ -264,7 +264,7 @@ def test_compile_nhwc_pack(): if __name__ == "__main__": test_get_valid_implementations() test_select_implementation() - test_compile_engine() + test_te_compiler() test_compile_placeholder_bypass() test_compile_injective_with_tuple() test_compile_tuple_dup() diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index cb8968cfc880..c76d7e145ecf 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -66,7 +66,8 @@ def check(m, n, target_bits, target_dtype): # const shape # i32 -> i32 check(2, 2, 32, "int32") - check(2 ** 16, 2 ** 16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow + # i32 + i32 is not promoted to i64 even if overflow + check(2 ** 16, 2 ** 16, 32, "int32") # i64 -> i32 check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32") check(const(2 ** 16, dtype="int64"), const(2 ** 16, dtype="int64"), 32, "int64") @@ -188,7 +189,7 @@ def check(m, n, target_bits, target_dtype): def test_relay_basic(): - engine = relay.backend.compile_engine.get() + engine = relay.backend.te_compiler.get() def check(shapex, shapey, target_bits, target_dtype): x = relay.var("x", shape=shapex) @@ -230,7 +231,7 @@ def check(shapex, shapey, target_bits, target_dtype): def test_relay_take(): - engine = relay.backend.compile_engine.get() + engine = relay.backend.te_compiler.get() def check(shape, index, target_bits, target_dtype): x = relay.var("x", shape=shape)