diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index 2daebad4e676..b37d44eda7b3 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -26,10 +26,10 @@ Python API ndarray error ir + target intrin tensor schedule - target build function autotvm diff --git a/docs/api/python/target.rst b/docs/api/python/target.rst index a0f3569c6060..6851c04c5b6b 100644 --- a/docs/api/python/target.rst +++ b/docs/api/python/target.rst @@ -19,3 +19,4 @@ tvm.target ---------- .. automodule:: tvm.target :members: + :imported-members: diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 69c24008c10f..9327c0865689 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -46,7 +46,6 @@ from . import stmt from . import make from . import ir_pass -from . import codegen from . import schedule from . import ir_builder @@ -55,7 +54,6 @@ from . import hybrid from . import testing from . import error -from . import datatype from .api import * diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index f779b516618c..10f0fec82a80 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -20,7 +20,6 @@ import json import numpy as np from .base import _LIB, check_call -from .. import _api_internal tvm_shape_index_t = ctypes.c_int64 @@ -48,6 +47,7 @@ class TVMByteArray(ctypes.Structure): _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), ("size", ctypes.c_size_t)] + class DataType(ctypes.Structure): """TVM datatype structure""" _fields_ = [("type_code", ctypes.c_uint8), @@ -89,11 +89,13 @@ def __init__(self, type_str): bits = 64 head = "" elif head.startswith("custom"): + # pylint: disable=import-outside-toplevel + import tvm.runtime._ffi_api low, high = head.find('['), head.find(']') if not low or not high or low >= high: raise ValueError("Badly formatted custom type string %s" % type_str) type_name = head[low + 1:high] - self.type_code = _api_internal._datatype_get_type_code(type_name) + self.type_code = tvm.runtime._ffi_api._datatype_get_type_code(type_name) head = head[high+1:] else: raise ValueError("Do not know how to handle type %s" % type_str) @@ -102,13 +104,15 @@ def __init__(self, type_str): def __repr__(self): + # pylint: disable=import-outside-toplevel if self.bits == 1 and self.lanes == 1: return "bool" if self.type_code in DataType.CODE2STR: type_name = DataType.CODE2STR[self.type_code] else: + import tvm.runtime._ffi_api type_name = "custom[%s]" % \ - _api_internal._datatype_get_type_name(self.type_code) + tvm.runtime._ffi_api._datatype_get_type_name(self.type_code) x = "%s%d" % (type_name, self.bits) if self.lanes != 1: x += "x%d" % self.lanes @@ -168,28 +172,35 @@ def __init__(self, device_type, device_id): self.device_type = device_type self.device_id = device_id + def _GetDeviceAttr(self, device_type, device_id, attr_id): + """Internal helper function to invoke runtime.GetDeviceAttr""" + # pylint: disable=import-outside-toplevel + import tvm.runtime._ffi_api + return tvm.runtime._ffi_api.GetDeviceAttr( + device_type, device_id, attr_id) + @property def exist(self): """Whether this device exist.""" - return _api_internal._GetDeviceAttr( + return self._GetDeviceAttr( self.device_type, self.device_id, 0) != 0 @property def max_threads_per_block(self): """Maximum number of threads on each block.""" - return _api_internal._GetDeviceAttr( + return self._GetDeviceAttr( self.device_type, self.device_id, 1) @property def warp_size(self): """Number of threads that executes in concurrent.""" - 
return _api_internal._GetDeviceAttr( + return self._GetDeviceAttr( self.device_type, self.device_id, 2) @property def max_shared_memory_per_block(self): """Total amount of shared memory per block in bytes.""" - return _api_internal._GetDeviceAttr( + return self._GetDeviceAttr( self.device_type, self.device_id, 3) @property @@ -203,25 +214,25 @@ def compute_version(self): version : str The version string in `major.minor` format. """ - return _api_internal._GetDeviceAttr( + return self._GetDeviceAttr( self.device_type, self.device_id, 4) @property def device_name(self): """Return the string name of device.""" - return _api_internal._GetDeviceAttr( + return self._GetDeviceAttr( self.device_type, self.device_id, 5) @property def max_clock_rate(self): """Return the max clock frequency of device.""" - return _api_internal._GetDeviceAttr( + return self._GetDeviceAttr( self.device_type, self.device_id, 6) @property def multi_processor_count(self): """Return the number of compute units of device.""" - return _api_internal._GetDeviceAttr( + return self._GetDeviceAttr( self.device_type, self.device_id, 7) @property @@ -233,7 +244,7 @@ def max_thread_dimensions(self): dims: List of int The maximum length of threadIdx.x, threadIdx.y, threadIdx.z """ - return json.loads(_api_internal._GetDeviceAttr( + return json.loads(self._GetDeviceAttr( self.device_type, self.device_id, 8)) def sync(self): diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 15ed2953b28a..28a9fbba2834 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -106,7 +106,7 @@ def update(self, target, workload, cfg): def _alter_conv2d_layout(attrs, inputs, tinfo): workload = get_conv2d_workload(...) dispatch_ctx = autotvm.task.DispatchContext.current - target = tvm.target.current_target() + target = tvm.target.Target.current() config = dispatch_ctx.query(target, workload) # Get conv2d_NCHWc workload from config @@ -207,7 +207,7 @@ def _do_reg(myf): def dispatch_func(func, *args, **kwargs): """The wrapped dispatch function""" - tgt = _target.current_target() + tgt = _target.Target.current() workload = func(*args, **kwargs) cfg = DispatchContext.current.query(tgt, workload) if cfg.is_fallback and not cfg.template_key: diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 768f43884418..c2993ac27819 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -25,6 +25,8 @@ from tvm.runtime import Object, ndarray from tvm.ir import container +from tvm.target import codegen + from . import api from . import _api_internal from . import tensor @@ -32,7 +34,6 @@ from . import expr from . import ir_pass from . import stmt as _stmt -from . import codegen from . import target as _target from . import make from .stmt import LoweredFunc @@ -602,7 +603,7 @@ def build(inputs, "LoweredFunc.") if not isinstance(inputs, (dict, container.Map)): - target = _target.current_target() if target is None else target + target = _target.Target.current() if target is None else target target = target if target else "llvm" target_flist = {target: flist} else: diff --git a/python/tvm/contrib/clang.py b/python/tvm/contrib/clang.py index c8c6c57d0538..cb7bdcc1fd31 100644 --- a/python/tvm/contrib/clang.py +++ b/python/tvm/contrib/clang.py @@ -16,11 +16,10 @@ # under the License. 
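The TVMContext properties above now all funnel through the new _GetDeviceAttr helper instead of _api_internal. A minimal usage sketch (the queried device and the printed values are illustrative and depend on the local machine):

import tvm

# Device attributes are resolved lazily through runtime.GetDeviceAttr.
ctx = tvm.context("cuda", 0)            # or tvm.cpu(0), tvm.opencl(0), ...
if ctx.exist:
    print(ctx.device_name)              # e.g. "GeForce GTX 1080 Ti"
    print(ctx.max_threads_per_block)    # e.g. 1024
    print(ctx.warp_size)                # e.g. 32
else:
    print("no CUDA device on this machine")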
"""Util to invoke clang in the system.""" # pylint: disable=invalid-name -from __future__ import absolute_import as _abs import subprocess -from .._ffi.base import py_str -from .. import codegen +from tvm._ffi.base import py_str +import tvm.target from . import util @@ -44,8 +43,8 @@ def find_clang(required=True): matches the major llvm version that built with tvm """ cc_list = [] - if hasattr(codegen, "llvm_version_major"): - major = codegen.llvm_version_major() + major = tvm.target.codegen.llvm_version_major(allow_none=True) + if major is not None: cc_list += ["clang-%d.0" % major] cc_list += ["clang-%d" % major] cc_list += ["clang"] diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py index fba57f8524d0..e5cebdd3f5dc 100644 --- a/python/tvm/contrib/rocm.py +++ b/python/tvm/contrib/rocm.py @@ -17,9 +17,11 @@ """Utility for ROCm backend""" import subprocess from os.path import join, exists + +from tvm._ffi.base import py_str +import tvm.target + from . import util -from .._ffi.base import py_str -from .. import codegen from ..api import register_func, convert def find_lld(required=True): @@ -42,8 +44,8 @@ def find_lld(required=True): matches the major llvm version that built with tvm """ lld_list = [] - if hasattr(codegen, "llvm_version_major"): - major = codegen.llvm_version_major() + major = tvm.target.codegen.llvm_version_major(allow_none=True) + if major is not None: lld_list += ["ld.lld-%d.0" % major] lld_list += ["ld.lld-%d" % major] lld_list += ["ld.lld"] diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 630c10fcf2dd..78ce2e20e1fd 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -154,8 +154,8 @@ def max_num_threads(func_id, args): _internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!") _internal_assert(args.__len__() <= 1, "At most one argument accepted!") if args.__len__() == 0: - res = _tgt.current_target().max_num_threads + res = _tgt.Target.current().max_num_threads else: _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint") - res = _tgt.current_target(args[0].value).max_num_threads + res = _tgt.Target.current(args[0].value).max_num_threads return _api.convert(res) diff --git a/python/tvm/hybrid/runtime.py b/python/tvm/hybrid/runtime.py index aa00b4b80251..9f92b80444e8 100644 --- a/python/tvm/hybrid/runtime.py +++ b/python/tvm/hybrid/runtime.py @@ -107,7 +107,7 @@ def sigmoid(x): def max_num_threads(allow_none=True): """Get max number of threads for GPU targets.""" - return target.current_target(allow_none).max_num_threads + return target.Target.current(allow_none).max_num_threads HYBRID_GLOBALS = { diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 6146a7189318..04cbf9ee86c7 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -17,7 +17,7 @@ """Expression Intrinsics and math functions in TVM.""" # pylint: disable=redefined-builtin import tvm._ffi -import tvm.codegen +import tvm.target.codegen from . import make as _make from .api import convert, const @@ -189,7 +189,7 @@ def call_llvm_intrin(dtype, name, *args): call : Expr The call expression. 
""" - llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name) + llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(name) assert llvm_id != 0, "%s is not an LLVM intrinsic" % name return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 557b9fd6e46d..68b2b1c97c03 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -176,7 +176,7 @@ def get_exec(self): def _update_target(self, target): """Update target.""" - target = target if target else tvm.target.current_target() + target = target if target else tvm.target.Target.current() if target is None: raise ValueError("Target is not set in env or passed as argument.") tgts = {} diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 8b347883fe3a..6d9c850cb7ff 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -33,7 +33,7 @@ from .backend.vm import VMExecutor def _update_target(target): - target = target if target else _target.current_target() + target = target if target else _target.Target.current() if target is None: raise ValueError("Target is not set in env or passed as argument.") diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 22785eec6b41..ad71313fef52 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -220,13 +220,13 @@ def _shift(data, zero_point, out_dtype): def is_fast_int8_on_intel(): """ Checks whether the hardware has support for fast Int8 arithmetic operations. """ - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'} return intel_supported_arches.intersection(set(target.options)) def is_fast_int8_on_arm(): """ Checks whether the hardware has support for fast Int8 arithmetic operations. """ - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) return '+v8.2a,+dotprod' in ' '.join(target.options) ######################## diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py index 482a6f292f54..8f83bfbf0659 100644 --- a/python/tvm/relay/quantize/_calibrate.py +++ b/python/tvm/relay/quantize/_calibrate.py @@ -37,8 +37,8 @@ def _get_profile_runtime(mod): func = mod['main'] func = _quantize.CreateStatsCollector(func) - if tvm.target.current_target(): - target = tvm.target.current_target() + if tvm.target.Target.current(): + target = tvm.target.Target.current() ctx = tvm.context(target.target_name) else: target = 'llvm' diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index c6a621db368a..fbac767cea24 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -16,9 +16,7 @@ # under the License. #pylint: disable=unused-argument,inconsistent-return-statements """Internal module for registering attribute for annotation.""" -from __future__ import absolute_import - -from ... import target as _target +import tvm from .. import expr as _expr from .. 
import analysis as _analysis from ..base import register_relay_node @@ -133,7 +131,7 @@ def add_partition_generic(ref_call, new_args, ctx): @register_partition_function("add") def add_partition_function(ref_call, new_args, ctx): """Rewrite function for ewise add for partition""" - target = _target.current_target() + target = tvm.target.Target.current() if target and 'cuda' in target.keys: #TODO(wuwei/ziheng) cuda specific rules return add_partition_generic(ref_call, new_args, ctx) diff --git a/python/tvm/target.py b/python/tvm/target.py deleted file mode 100644 index e149c890e0ad..000000000000 --- a/python/tvm/target.py +++ /dev/null @@ -1,559 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Target management API of TVM. - -TVM's target string is in fomat `` [-option=value]...``. - -Note ----- -The list of options include: - -- **-device=** - - The device name. - -- **-mtriple=** or **-target** - - Specify the target triple, which is useful for cross - compilation. - -- **-mcpu=** - - Specify a specific chip in the current architecture to - generate code for. By default this is infered from the - target triple and autodetected to the current architecture. - -- **-mattr=a1,+a2,-a3,...** - - Override or control specific attributes of the target, - such as whether SIMD operations are enabled or not. The - default set of attributes is set by the current CPU. - -- **-system-lib** - - Build TVM system library module. System lib is a global module that contains - self registered functions in program startup. User can get the module using - :any:`tvm.runtime.system_lib`. - It is useful in environments where dynamic loading api like dlopen is banned. - The system lib will be available as long as the result code is linked by the program. - -We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string. -We can also use other specific function in this module to create specific targets. -""" -import warnings -import tvm._ffi - -from tvm.runtime import Object -from ._ffi.base import _LIB_NAME -from . import _api_internal - -try: - from decorator import decorate -except ImportError as err_msg: - # Allow decorator to be missing in runtime - if _LIB_NAME != "libtvm_runtime.so": - raise err_msg - -def _merge_opts(opts, new_opts): - """Helper function to merge options""" - if isinstance(new_opts, str): - new_opts = new_opts.split() - if new_opts: - opt_set = set(opts) - new_opts = [opt for opt in new_opts if opt not in opt_set] - return opts + new_opts - return opts - - -@tvm._ffi.register_object -class Target(Object): - """Target device information, use through TVM API. 
- - Note - ---- - Do not use class constructor, you can create target using the following functions - - - :any:`tvm.target.create` create target from string - - :any:`tvm.target.arm_cpu` create arm_cpu target - - :any:`tvm.target.cuda` create CUDA target - - :any:`tvm.target.rocm` create ROCM target - - :any:`tvm.target.mali` create Mali target - - :any:`tvm.target.intel_graphics` create Intel Graphics target - """ - def __new__(cls): - # Always override new to enable class - obj = Object.__new__(cls) - obj._keys = None - obj._options = None - obj._libs = None - return obj - - @property - def keys(self): - if not self._keys: - self._keys = [k.value for k in self.keys_array] - return self._keys - - @property - def options(self): - if not self._options: - self._options = [o.value for o in self.options_array] - return self._options - - @property - def libs(self): - if not self._libs: - self._libs = [l.value for l in self.libs_array] - return self._libs - - @property - def model(self): - for opt in self.options_array: - if opt.value.startswith('-model='): - return opt.value[7:] - return 'unknown' - - @property - def mcpu(self): - """Returns the mcpu from the target if it exists.""" - mcpu = '' - if self.options is not None: - for opt in self.options: - if 'mcpu' in opt: - mcpu = opt.split('=')[1] - return mcpu - - def __enter__(self): - _api_internal._EnterTargetScope(self) - return self - - def __exit__(self, ptype, value, trace): - _api_internal._ExitTargetScope(self) - - -@tvm._ffi.register_object -class GenericFunc(Object): - """GenericFunc node reference. This represents a generic function - that may be specialized for different targets. When this object is - called, a specialization is chosen based on the current target. - - Note - ---- - Do not construct an instance of this object, it should only ever be - used as a return value from calling into C++. - """ - def __call__(self, *args): - return _api_internal._GenericFuncCallFunc(self, *args) - - def set_default(self, func, allow_override=False): - """Set the default function to be used if no specializations match - the current target. - - Parameters - ---------- - func : function - The default function - - allow_override : bool - Whether to allow the current default to be overridden - """ - _api_internal._GenericFuncSetDefault(self, func, allow_override) - - def register(self, func, key_list, allow_override=False): - """Register a specialization for this GenericFunc. - - Parameters - ---------- - func : function - The function to be registered. - - key : str or list of str - The key to be registered. - - allow_override : bool, optional - Whether to allow existing keys to be overridden. - """ - key_list = [key_list] if isinstance(key_list, str) else key_list - _api_internal._GenericFuncRegisterFunc(self, func, key_list, allow_override) - - -def get_native_generic_func(name): - """Get a generic function from the global registry. If no - function is registered under the given name, a new generic - function is created. - - Parameters - ---------- - name : string - The name of the generic function to get - - Returns - ------- - func : GenericFunc - The generic function for the given name - """ - return _api_internal._GenericFuncGetGlobal(name) - - -def override_native_generic_func(func_name): - """Override a generic function defined in C++ - - Generic function allows registration of further functions - that can be dispatched on current target context. - If no registered dispatch is matched, the fdefault will be called. 
- - Parameters - ---------- - func_name : string - The name of the generic func to be overridden - - Returns - ------- - fgeneric : function - A wrapped generic function. - - Example - ------- - .. code-block:: python - - import tvm - # wrap function as target generic - @tvm.target.override_native_generic_func("my_func") - def my_func(a): - return a + 1 - # register specialization of my_func under target cuda - @my_func.register("cuda") - def my_func_cuda(a): - return a + 2 - # displays 3, because my_func is called - print(my_func(2)) - # displays 4, because my_func_cuda is called - with tvm.target.cuda(): - print(my_func(2)) - """ - generic_func_node = get_native_generic_func(func_name) - - def fdecorate(fdefault): - """Wrap a target generic function, overriding the previous - default that was set for the generic function. - - Parameters - ---------- - fdefault : function - The default function. - - Returns - ------- - fgeneric : function - A wrapped generic function. - - """ - generic_func_node.set_default(fdefault, allow_override=True) - - def register(key, func=None, override=True): - """Register function to be the dispatch function. - - Parameters - ---------- - key : str or list of str - The key to be registered. - - func : function - The function to be registered. - - override : bool, optional - Whether override existing registration. - - Returns - ------- - The register function is necessary. - """ - def _do_reg(myf): - generic_func_node.register(myf, key, override) - return myf - if func: - return _do_reg(func) - return _do_reg - - def dispatch_func(func, *args, **kwargs): - #pylint: disable=unused-argument - """The wrapped dispath function""" - if kwargs: - raise RuntimeError( - "Keyword arguments cannot be used when invoking generic_func %s" % func_name) - return generic_func_node(*args) - fresult = decorate(fdefault, dispatch_func) - fresult.fdefault = fdefault - fresult.register = register - return fresult - return fdecorate - -def generic_func(fdefault): - """Wrap a target generic function. - - Generic function allows registration of further functions - that can be dispatched on current target context. - If no registered dispatch is matched, the fdefault will be called. - - Parameters - ---------- - fdefault : function - The default function. - - Returns - ------- - fgeneric : function - A wrapped generic function. - - Example - ------- - .. code-block:: python - - import tvm - # wrap function as target generic - @tvm.target.generic_func - def my_func(a): - return a + 1 - # register specialization of my_func under target cuda - @my_func.register("cuda") - def my_func_cuda(a): - return a + 2 - # displays 3, because my_func is called - print(my_func(2)) - # displays 4, because my_func_cuda is called - with tvm.target.cuda(): - print(my_func(2)) - """ - dispatch_dict = {} - func_name = fdefault.__name__ - - def register(key, func=None, override=False): - """Register function to be the dispatch function. - - Parameters - ---------- - key : str or list of str - The key to be registered. - - func : function - The function to be registered. - - override : bool - Whether override existing registration. - - Returns - ------- - The register function is necessary. 
- """ - def _do_reg(myf): - key_list = [key] if isinstance(key, str) else key - for k in key_list: - if k in dispatch_dict and not override: - raise ValueError( - "Key is already registered for %s" % func_name) - dispatch_dict[k] = myf - return myf - if func: - return _do_reg(func) - return _do_reg - - def dispatch_func(func, *args, **kwargs): - """The wrapped dispath function""" - target = current_target() - if target is None: - return func(*args, **kwargs) - for k in target.keys: - if k in dispatch_dict: - return dispatch_dict[k](*args, **kwargs) - return func(*args, **kwargs) - fdecorate = decorate(fdefault, dispatch_func) - fdecorate.register = register - fdecorate.fdefault = fdefault - return fdecorate - - -def cuda(model='unknown', options=None): - """Returns a cuda target. - - Parameters - ---------- - model: str - The model of cuda device (e.g. 1080ti) - options : str or list of str - Additional options - """ - opts = _merge_opts(['-model=%s' % model], options) - return _api_internal._TargetCreate("cuda", *opts) - - -def rocm(model='unknown', options=None): - """Returns a ROCM target. - - Parameters - ---------- - model: str - The model of this device - options : str or list of str - Additional options - """ - opts = _merge_opts(["-model=%s" % model], options) - return _api_internal._TargetCreate("rocm", *opts) - - -def mali(model='unknown', options=None): - """Returns a ARM Mali GPU target. - - Parameters - ---------- - model: str - The model of this device - options : str or list of str - Additional options - """ - opts = ["-device=mali", '-model=%s' % model] - opts = _merge_opts(opts, options) - return _api_internal._TargetCreate("opencl", *opts) - - -def intel_graphics(model='unknown', options=None): - """Returns an Intel Graphics target. - - Parameters - ---------- - model: str - The model of this device - options : str or list of str - Additional options - """ - opts = ["-device=intel_graphics", '-model=%s' % model] - opts = _merge_opts(opts, options) - return _api_internal._TargetCreate("opencl", *opts) - - -def opengl(model='unknown', options=None): - """Returns a OpenGL target. - - Parameters - ---------- - options : str or list of str - Additional options - """ - opts = _merge_opts(["-model=%s" % model], options) - return _api_internal._TargetCreate("opengl", *opts) - - -def arm_cpu(model='unknown', options=None): - """Returns a ARM CPU target. - This function will also download pre-tuned op parameters when there is none. - - Parameters - ---------- - model: str - SoC name or phone name of the arm board. 
- options : str or list of str - Additional options - """ - trans_table = { - "pixel2": ["-model=snapdragon835", "-target=arm64-linux-android -mattr=+neon"], - "mate10": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"], - "mate10pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"], - "p20": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"], - "p20pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"], - "rasp3b": ["-model=bcm2837", "-target=armv7l-linux-gnueabihf -mattr=+neon"], - "rasp4b": ["-model=bcm2711", "-target=arm-linux-gnueabihf -mattr=+neon"], - "rk3399": ["-model=rk3399", "-target=aarch64-linux-gnu -mattr=+neon"], - "pynq": ["-model=pynq", "-target=armv7a-linux-eabi -mattr=+neon"], - "ultra96": ["-model=ultra96", "-target=aarch64-linux-gnu -mattr=+neon"], - } - pre_defined_opt = trans_table.get(model, ["-model=%s" % model]) - - opts = ["-device=arm_cpu"] + pre_defined_opt - opts = _merge_opts(opts, options) - return _api_internal._TargetCreate("llvm", *opts) - - -def rasp(options=None): - """Return a Raspberry 3b target. - - Parameters - ---------- - options : str or list of str - Additional options - """ - warnings.warn('tvm.target.rasp() is going to be deprecated. ' - 'Please use tvm.target.arm_cpu("rasp3b")') - return arm_cpu('rasp3b', options) - - -def vta(model='unknown', options=None): - opts = ["-device=vta", '-keys=cpu', '-model=%s' % model] - opts = _merge_opts(opts, options) - ret = _api_internal._TargetCreate("ext_dev", *opts) - return ret - - -def bifrost(model='unknown', options=None): - """Return an ARM Mali GPU target (Bifrost architecture). - - Parameters - ---------- - options : str or list of str - Additional options - """ - opts = ["-device=bifrost", '-model=%s' % model] - opts = _merge_opts(opts, options) - return _api_internal._TargetCreate("opencl", *opts) - - -def create(target_str): - """Get a target given target string. - - Parameters - ---------- - target_str : str - The target string. - - Returns - ------- - target : Target - The target object - - Note - ---- - See the note on :any:`tvm.target` on target string format. - """ - if isinstance(target_str, Target): - return target_str - if not isinstance(target_str, str): - raise ValueError("target_str has to be string type") - - return _api_internal._TargetFromString(target_str) - - -def current_target(allow_none=True): - """Returns the current target. - - Parameters - ---------- - allow_none : bool - Whether allow the current target to be none - - Raises - ------ - ValueError if current target is not set. - """ - return _api_internal._GetCurrentTarget(allow_none) diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py new file mode 100644 index 000000000000..abe8436a55ba --- /dev/null +++ b/python/tvm/target/__init__.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Target description and codgen module. + +TVM's target string is in fomat `` [-option=value]...``. + +Note +---- +The list of options include: + +- **-device=** + + The device name. + +- **-mtriple=** or **-target** + + Specify the target triple, which is useful for cross + compilation. + +- **-mcpu=** + + Specify a specific chip in the current architecture to + generate code for. By default this is infered from the + target triple and autodetected to the current architecture. + +- **-mattr=a1,+a2,-a3,...** + + Override or control specific attributes of the target, + such as whether SIMD operations are enabled or not. The + default set of attributes is set by the current CPU. + +- **-system-lib** + + Build TVM system library module. System lib is a global module that contains + self registered functions in program startup. User can get the module using + :any:`tvm.runtime.system_lib`. + It is useful in environments where dynamic loading api like dlopen is banned. + The system lib will be available as long as the result code is linked by the program. + +We can use :py:func:`~tvm.target.create` to create a tvm.target.Target from the target string. +We can also use other specific function in this module to create specific targets. +""" +from .target import Target, create +from .target import cuda, rocm, mali, intel_graphics, opengl, arm_cpu, rasp, vta, bifrost +from .generic_func import GenericFunc +from .generic_func import generic_func, get_native_generic_func, override_native_generic_func +from . import datatype +from . import codegen diff --git a/python/tvm/codegen.py b/python/tvm/target/_ffi_api.py similarity index 65% rename from python/tvm/codegen.py rename to python/tvm/target/_ffi_api.py index 7dc7bea90076..3f3c4f2b8e46 100644 --- a/python/tvm/codegen.py +++ b/python/tvm/target/_ffi_api.py @@ -14,25 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Code generation related functions.""" +"""FFI APIs for tvm.target""" import tvm._ffi -def build_module(lowered_func, target): - """Build lowered_func into Module. - Parameters - ---------- - lowered_func : LoweredFunc - The lowered function - - target : str - The target module type. - - Returns - ------- - module : Module - The corressponding module. - """ - return _Build(lowered_func, target) - -tvm._ffi._init_api("tvm.codegen") +tvm._ffi._init_api("target", __name__) diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py new file mode 100644 index 000000000000..e7bedaa1bbad --- /dev/null +++ b/python/tvm/target/codegen.py @@ -0,0 +1,76 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
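A short sketch of the documented target-string format together with the helpers re-exported from the new package root (option values and printed results are illustrative):

import tvm

# "<name> [-option=value]..." strings are parsed by tvm.target.create.
tgt = tvm.target.create("llvm -mcpu=skylake-avx512 -system-lib")
print(tgt.keys)      # e.g. ['cpu']
print(tgt.options)   # e.g. ['-mcpu=skylake-avx512', '-system-lib']
print(tgt.mcpu)      # e.g. 'skylake-avx512'

# The device-specific constructors remain available as before.
cuda_tgt = tvm.target.cuda(model='1080ti')
print(cuda_tgt.model)   # '1080ti'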
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Code generation related functions.""" +from . import _ffi_api + + +def build_module(lowered_func, target): + """Build lowered_func into Module. + + Parameters + ---------- + lowered_func : LoweredFunc + The lowered function + + target : str + The target module type. + + Returns + ------- + module : runtime.Module + The corressponding module. + """ + return _ffi_api.Build(lowered_func, target) + + +def llvm_lookup_intrinsic_id(name): + """Lookup LLVM intrinsic id by name. + + Parameters + ---------- + name : str + The name of the intrinsic. + + Returns + ------- + intrin_id : int + The intrinsic id. + """ + return _ffi_api.llvm_lookup_intrinsic_id(name) + + +def llvm_version_major(allow_none=False): + """Get the major LLVM version. + + Parameters + ---------- + allow_none : bool + Whether do we allow none. + + Returns + ------- + major : int + The major LLVM version. + """ + try: + return _ffi_api.llvm_version_major() + except AttributeError: + if allow_none: + return None + raise RuntimeError( + "LLVM version is not available, please check if you build with LLVM") diff --git a/python/tvm/datatype.py b/python/tvm/target/datatype.py similarity index 85% rename from python/tvm/datatype.py rename to python/tvm/target/datatype.py index 8a936731b8ca..a9506b3339cb 100644 --- a/python/tvm/datatype.py +++ b/python/tvm/target/datatype.py @@ -17,11 +17,9 @@ """Custom datatype functionality""" import tvm._ffi -from . import make as _make -from .api import convert -from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm -from ._ffi.runtime_ctypes import DataType -from . import _api_internal +import tvm.runtime._ffi_api +from tvm.runtime import convert, DataType +from tvm.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm def register(type_name, type_code): @@ -39,7 +37,7 @@ def register(type_name, type_code): type_code : int The type's code, which should be >= kCustomBegin """ - _api_internal._datatype_register(type_name, type_code) + tvm.runtime._ffi_api._datatype_register(type_name, type_code) def get_type_name(type_code): @@ -50,7 +48,7 @@ def get_type_name(type_code): type_code : int The type code """ - return _api_internal._datatype_get_type_name(type_code) + return tvm.runtime._ffi_api._datatype_get_type_name(type_code) def get_type_code(type_name): @@ -61,7 +59,7 @@ def get_type_code(type_name): type_name : str The type name """ - return _api_internal._datatype_get_type_code(type_name) + return tvm.runtime._ffi_api._datatype_get_type_code(type_name) def get_type_registered(type_code): @@ -72,7 +70,7 @@ def get_type_registered(type_code): type_code: int The type code """ - return _api_internal._datatype_get_type_registered(type_code) + return tvm.runtime._ffi_api._datatype_get_type_registered(type_code) def register_op(lower_func, op_name, target, type_name, src_type_name=None): @@ -137,9 +135,9 @@ def lower(op): if t.lanes > 1: dtype += "x" + str(t.lanes) if isinstance(op, (_Cast, _FloatImm)): - return _make.Call(dtype, extern_func_name, convert([op.value]), - _Call.Extern, None, 0) - return _make.Call(dtype, extern_func_name, convert([op.a, op.b]), - _Call.Extern, None, 0) + return _Call(dtype, extern_func_name, convert([op.value]), + _Call.Extern, None, 0) + return _Call(dtype, extern_func_name, convert([op.a, op.b]), + _Call.Extern, None, 0) return lower diff --git a/python/tvm/target/generic_func.py b/python/tvm/target/generic_func.py new file mode 100644 
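The relocated datatype module keeps the registry-backed workflow; a sketch where the "bfloat" name and type code 129 are illustrative values for a custom type code >= kCustomBegin:

import tvm

# Register a custom type, then refer to it through type strings.
tvm.target.datatype.register("bfloat", 129)
code = tvm.target.datatype.get_type_code("bfloat")
assert tvm.target.datatype.get_type_registered(code)

# "custom[...]" type strings resolve through the same runtime registry
# that DataType.__repr__ now queries.
t = tvm.runtime.DataType("custom[bfloat]16")
print(t)   # custom[bfloat]16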
index 000000000000..862fbedee0a4 --- /dev/null +++ b/python/tvm/target/generic_func.py @@ -0,0 +1,271 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Generic function.""" + +import tvm._ffi + +try: + from decorator import decorate +except ImportError as err_msg: + # Allow decorator to be missing in runtime + if _LIB_NAME != "libtvm_runtime.so": + raise err_msg + +from tvm.runtime import Object +from . target import Target +from . import _ffi_api + + +@tvm._ffi.register_object +class GenericFunc(Object): + """GenericFunc node reference. This represents a generic function + that may be specialized for different targets. When this object is + called, a specialization is chosen based on the current target. + + Note + ---- + Do not construct an instance of this object, it should only ever be + used as a return value from calling into C++. + """ + def __call__(self, *args): + return _ffi_api.GenericFuncCallFunc(self, *args) + + def set_default(self, func, allow_override=False): + """Set the default function to be used if no specializations match + the current target. + + Parameters + ---------- + func : function + The default function + + allow_override : bool + Whether to allow the current default to be overridden + """ + _ffi_api.GenericFuncSetDefault(self, func, allow_override) + + def register(self, func, key_list, allow_override=False): + """Register a specialization for this GenericFunc. + + Parameters + ---------- + func : function + The function to be registered. + + key : str or list of str + The key to be registered. + + allow_override : bool, optional + Whether to allow existing keys to be overridden. + """ + key_list = [key_list] if isinstance(key_list, str) else key_list + _ffi_api.GenericFuncRegisterFunc(self, func, key_list, allow_override) + + +def get_native_generic_func(name): + """Get a generic function from the global registry. If no + function is registered under the given name, a new generic + function is created. + + Parameters + ---------- + name : string + The name of the generic function to get + + Returns + ------- + func : GenericFunc + The generic function for the given name + """ + return _ffi_api.GenericFuncGetGlobal(name) + + +def override_native_generic_func(func_name): + """Override a generic function defined in C++ + + Generic function allows registration of further functions + that can be dispatched on current target context. + If no registered dispatch is matched, the fdefault will be called. + + Parameters + ---------- + func_name : string + The name of the generic func to be overridden + + Returns + ------- + fgeneric : function + A wrapped generic function. + + Example + ------- + .. 
code-block:: python + + import tvm + # wrap function as target generic + @tvm.target.override_native_generic_func("my_func") + def my_func(a): + return a + 1 + # register specialization of my_func under target cuda + @my_func.register("cuda") + def my_func_cuda(a): + return a + 2 + # displays 3, because my_func is called + print(my_func(2)) + # displays 4, because my_func_cuda is called + with tvm.target.cuda(): + print(my_func(2)) + """ + generic_func_node = get_native_generic_func(func_name) + + def fdecorate(fdefault): + """Wrap a target generic function, overriding the previous + default that was set for the generic function. + + Parameters + ---------- + fdefault : function + The default function. + + Returns + ------- + fgeneric : function + A wrapped generic function. + + """ + generic_func_node.set_default(fdefault, allow_override=True) + + def register(key, func=None, override=True): + """Register function to be the dispatch function. + + Parameters + ---------- + key : str or list of str + The key to be registered. + + func : function + The function to be registered. + + override : bool, optional + Whether override existing registration. + + Returns + ------- + The register function is necessary. + """ + def _do_reg(myf): + generic_func_node.register(myf, key, override) + return myf + if func: + return _do_reg(func) + return _do_reg + + def dispatch_func(func, *args, **kwargs): + #pylint: disable=unused-argument + """The wrapped dispath function""" + if kwargs: + raise RuntimeError( + "Keyword arguments cannot be used when invoking generic_func %s" % func_name) + return generic_func_node(*args) + fresult = decorate(fdefault, dispatch_func) + fresult.fdefault = fdefault + fresult.register = register + return fresult + return fdecorate + +def generic_func(fdefault): + """Wrap a target generic function. + + Generic function allows registration of further functions + that can be dispatched on current target context. + If no registered dispatch is matched, the fdefault will be called. + + Parameters + ---------- + fdefault : function + The default function. + + Returns + ------- + fgeneric : function + A wrapped generic function. + + Example + ------- + .. code-block:: python + + import tvm + # wrap function as target generic + @tvm.target.generic_func + def my_func(a): + return a + 1 + # register specialization of my_func under target cuda + @my_func.register("cuda") + def my_func_cuda(a): + return a + 2 + # displays 3, because my_func is called + print(my_func(2)) + # displays 4, because my_func_cuda is called + with tvm.target.cuda(): + print(my_func(2)) + """ + dispatch_dict = {} + func_name = fdefault.__name__ + + def register(key, func=None, override=False): + """Register function to be the dispatch function. + + Parameters + ---------- + key : str or list of str + The key to be registered. + + func : function + The function to be registered. + + override : bool + Whether override existing registration. + + Returns + ------- + The register function is necessary. 
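Besides the decorator form shown in the docstrings, the GenericFunc object can be driven directly through the global registry; a sketch where "demo.add_one" is a hypothetical function name created on first lookup:

import tvm

gfunc = tvm.target.get_native_generic_func("demo.add_one")
gfunc.set_default(lambda a: a + 1, allow_override=True)
gfunc.register(lambda a: a + 2, "cuda", allow_override=True)

print(gfunc(2))          # 3: no target scope, the default is used
with tvm.target.cuda():
    print(gfunc(2))      # 4: the "cuda" specialization is dispatched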
+ """ + def _do_reg(myf): + key_list = [key] if isinstance(key, str) else key + for k in key_list: + if k in dispatch_dict and not override: + raise ValueError( + "Key is already registered for %s" % func_name) + dispatch_dict[k] = myf + return myf + if func: + return _do_reg(func) + return _do_reg + + def dispatch_func(func, *args, **kwargs): + """The wrapped dispath function""" + target = Target.current() + if target is None: + return func(*args, **kwargs) + for k in target.keys: + if k in dispatch_dict: + return dispatch_dict[k](*args, **kwargs) + return func(*args, **kwargs) + fdecorate = decorate(fdefault, dispatch_func) + fdecorate.register = register + fdecorate.fdefault = fdefault + return fdecorate diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py new file mode 100644 index 000000000000..8405bb10720f --- /dev/null +++ b/python/tvm/target/target.py @@ -0,0 +1,272 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Target data structure.""" +import warnings +import tvm._ffi + +from tvm.runtime import Object +from . import _ffi_api + + +@tvm._ffi.register_object +class Target(Object): + """Target device information, use through TVM API. + + Note + ---- + Do not use class constructor, you can create target using the following functions + + - :py:func:`~tvm.target.create` create target from string + - :py:func:`~tvm.target.arm_cpu` create arm_cpu target + - :py:func:`~tvm.target.cuda` create CUDA target + - :py:func:`~tvm.target.rocm` create ROCM target + - :py:func:`~tvm.target.mali` create Mali target + - :py:func:`~tvm.target.intel_graphics` create Intel Graphics target + """ + def __new__(cls): + # Always override new to enable class + obj = Object.__new__(cls) + obj._keys = None + obj._options = None + obj._libs = None + return obj + + @property + def keys(self): + if not self._keys: + self._keys = [k.value for k in self.keys_array] + return self._keys + + @property + def options(self): + if not self._options: + self._options = [o.value for o in self.options_array] + return self._options + + @property + def libs(self): + if not self._libs: + self._libs = [l.value for l in self.libs_array] + return self._libs + + @property + def model(self): + for opt in self.options_array: + if opt.value.startswith('-model='): + return opt.value[7:] + return 'unknown' + + @property + def mcpu(self): + """Returns the mcpu from the target if it exists.""" + mcpu = '' + if self.options is not None: + for opt in self.options: + if 'mcpu' in opt: + mcpu = opt.split('=')[1] + return mcpu + + def __enter__(self): + _ffi_api.EnterTargetScope(self) + return self + + def __exit__(self, ptype, value, trace): + _ffi_api.ExitTargetScope(self) + + @staticmethod + def current(allow_none=True): + """Returns the current target. 
+ + Parameters + ---------- + allow_none : bool + Whether allow the current target to be none + + Raises + ------ + ValueError if current target is not set. + """ + return _ffi_api.GetCurrentTarget(allow_none) + + +def _merge_opts(opts, new_opts): + """Helper function to merge options""" + if isinstance(new_opts, str): + new_opts = new_opts.split() + if new_opts: + opt_set = set(opts) + new_opts = [opt for opt in new_opts if opt not in opt_set] + return opts + new_opts + return opts + + +def cuda(model='unknown', options=None): + """Returns a cuda target. + + Parameters + ---------- + model: str + The model of cuda device (e.g. 1080ti) + options : str or list of str + Additional options + """ + opts = _merge_opts(['-model=%s' % model], options) + return _ffi_api.TargetCreate("cuda", *opts) + + +def rocm(model='unknown', options=None): + """Returns a ROCM target. + + Parameters + ---------- + model: str + The model of this device + options : str or list of str + Additional options + """ + opts = _merge_opts(["-model=%s" % model], options) + return _ffi_api.TargetCreate("rocm", *opts) + + +def mali(model='unknown', options=None): + """Returns a ARM Mali GPU target. + + Parameters + ---------- + model: str + The model of this device + options : str or list of str + Additional options + """ + opts = ["-device=mali", '-model=%s' % model] + opts = _merge_opts(opts, options) + return _ffi_api.TargetCreate("opencl", *opts) + + +def intel_graphics(model='unknown', options=None): + """Returns an Intel Graphics target. + + Parameters + ---------- + model: str + The model of this device + options : str or list of str + Additional options + """ + opts = ["-device=intel_graphics", '-model=%s' % model] + opts = _merge_opts(opts, options) + return _ffi_api.TargetCreate("opencl", *opts) + + +def opengl(model='unknown', options=None): + """Returns a OpenGL target. + + Parameters + ---------- + options : str or list of str + Additional options + """ + opts = _merge_opts(["-model=%s" % model], options) + return _ffi_api.TargetCreate("opengl", *opts) + + +def arm_cpu(model='unknown', options=None): + """Returns a ARM CPU target. + This function will also download pre-tuned op parameters when there is none. + + Parameters + ---------- + model: str + SoC name or phone name of the arm board. + options : str or list of str + Additional options + """ + trans_table = { + "pixel2": ["-model=snapdragon835", "-target=arm64-linux-android -mattr=+neon"], + "mate10": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"], + "mate10pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"], + "p20": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"], + "p20pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"], + "rasp3b": ["-model=bcm2837", "-target=armv7l-linux-gnueabihf -mattr=+neon"], + "rasp4b": ["-model=bcm2711", "-target=arm-linux-gnueabihf -mattr=+neon"], + "rk3399": ["-model=rk3399", "-target=aarch64-linux-gnu -mattr=+neon"], + "pynq": ["-model=pynq", "-target=armv7a-linux-eabi -mattr=+neon"], + "ultra96": ["-model=ultra96", "-target=aarch64-linux-gnu -mattr=+neon"], + } + pre_defined_opt = trans_table.get(model, ["-model=%s" % model]) + + opts = ["-device=arm_cpu"] + pre_defined_opt + opts = _merge_opts(opts, options) + return _ffi_api.TargetCreate("llvm", *opts) + + +def rasp(options=None): + """Return a Raspberry 3b target. 
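The predefined board table in arm_cpu expands known names into full option sets, and any Target can be used as a context manager now that __enter__/__exit__ call the namespaced scope functions; a sketch using the rasp3b entry from the table above:

import tvm

rasp_tgt = tvm.target.arm_cpu("rasp3b")
print(rasp_tgt)         # e.g. llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon
print(rasp_tgt.model)   # 'bcm2837'

with rasp_tgt:
    # Entering the with-block goes through target.EnterTargetScope.
    assert tvm.target.Target.current().model == 'bcm2837'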
+ + Parameters + ---------- + options : str or list of str + Additional options + """ + warnings.warn('tvm.target.rasp() is going to be deprecated. ' + 'Please use tvm.target.arm_cpu("rasp3b")') + return arm_cpu('rasp3b', options) + + +def vta(model='unknown', options=None): + opts = ["-device=vta", '-keys=cpu', '-model=%s' % model] + opts = _merge_opts(opts, options) + ret = _ffi_api.TargetCreate("ext_dev", *opts) + return ret + + +def bifrost(model='unknown', options=None): + """Return an ARM Mali GPU target (Bifrost architecture). + + Parameters + ---------- + options : str or list of str + Additional options + """ + opts = ["-device=bifrost", '-model=%s' % model] + opts = _merge_opts(opts, options) + return _ffi_api.TargetCreate("opencl", *opts) + + +def create(target_str): + """Get a target given target string. + + Parameters + ---------- + target_str : str + The target string. + + Returns + ------- + target : Target + The target object + + Note + ---- + See the note on :py:mod:`~tvm.target` on target string format. + """ + if isinstance(target_str, Target): + return target_str + if not isinstance(target_str, str): + raise ValueError("target_str has to be string type") + + return _ffi_api.TargetFromString(target_str) diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index c60b4a8b95b1..8af2bd00bb35 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -46,20 +46,20 @@ namespace tvm { namespace runtime { std::string GetCustomTypeName(uint8_t type_code) { - auto f = tvm::runtime::Registry::Get("_datatype_get_type_name"); - CHECK(f) << "Function _datatype_get_type_name not found"; + auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_name"); + CHECK(f) << "Function runtime._datatype_get_type_name not found"; return (*f)(type_code).operator std::string(); } uint8_t GetCustomTypeCode(const std::string& type_name) { - auto f = tvm::runtime::Registry::Get("_datatype_get_type_code"); - CHECK(f) << "Function _datatype_get_type_code not found"; + auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_code"); + CHECK(f) << "Function runtime._datatype_get_type_code not found"; return (*f)(type_name).operator int(); } bool GetCustomTypeRegistered(uint8_t type_code) { - auto f = tvm::runtime::Registry::Get("_datatype_get_type_registered"); - CHECK(f) << "Function _datatype_get_type_registered not found"; + auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"); + CHECK(f) << "Function runtime._datatype_get_type_registered not found"; return (*f)(type_code).operator bool(); } @@ -612,7 +612,7 @@ TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) }); // set device api -TVM_REGISTER_GLOBAL("_GetDeviceAttr") +TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr") .set_body([](TVMArgs args, TVMRetValue *ret) { TVMContext ctx; ctx.device_type = static_cast(args[0].operator int()); diff --git a/src/target/codegen.cc b/src/target/codegen.cc index a9c820160cde..ee5e6a62b646 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -244,7 +244,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, return (*codegen_f)(blob_byte_array, system_lib, target_triple); } -TVM_REGISTER_GLOBAL("codegen._Build") +TVM_REGISTER_GLOBAL("target.Build") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { *ret = Build({args[0]}, args[1]); diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index b49b395ec08a..c16182da3674 100644 --- 
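The renamed C++ registrations remain reachable as packed functions from Python if needed; a sketch using the device-attribute query (attribute id 0 is assumed here to be the existence check used by the TVMContext.exist property above):

import tvm

# runtime.GetDeviceAttr replaces the old _GetDeviceAttr global.
fattr = tvm.get_global_func("runtime.GetDeviceAttr")
cpu_exists = fattr(tvm.cpu(0).device_type, 0, 0)   # device_type, device_id, attr id
print(bool(cpu_exists))    # True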
a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -25,22 +25,22 @@ namespace datatype { using runtime::TVMArgs; using runtime::TVMRetValue; -TVM_REGISTER_GLOBAL("_datatype_register") +TVM_REGISTER_GLOBAL("runtime._datatype_register") .set_body([](TVMArgs args, TVMRetValue* ret) { datatype::Registry::Global()->Register(args[0], static_cast(args[1].operator int())); }); -TVM_REGISTER_GLOBAL("_datatype_get_type_code") +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = datatype::Registry::Global()->GetTypeCode(args[0]); }); -TVM_REGISTER_GLOBAL("_datatype_get_type_name") +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = Registry::Global()->GetTypeName(args[0].operator int()); }); -TVM_REGISTER_GLOBAL("_datatype_get_type_registered") +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); }); @@ -90,7 +90,6 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t } else { ss << runtime::TypeCode2Str(src_type_code); } - return runtime::Registry::Get(ss.str()); } diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 817d48f0cdbf..8eef4b75ff40 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -123,18 +123,18 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { func.CallPacked(args, ret); } -TVM_REGISTER_GLOBAL("_GenericFuncCreate") +TVM_REGISTER_GLOBAL("target.GenericFuncCreate") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = GenericFunc(make_object()); }); -TVM_REGISTER_GLOBAL("_GenericFuncGetGlobal") +TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string func_name = args[0]; *ret = GenericFunc::Get(func_name); }); -TVM_REGISTER_GLOBAL("_GenericFuncSetDefault") +TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault") .set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown @@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncSetDefault") .set_default(*func, allow_override); }); -TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") +TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") .set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown @@ -162,7 +162,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") .register_func(tags_vector, *func, allow_override); }); -TVM_REGISTER_GLOBAL("_GenericFuncCallFunc") +TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc") .set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index f28bad4d63a6..30755fcfc125 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -349,11 +349,6 @@ unsigned LookupLLVMIntrinsic(const std::string& name) { return llvm::Function::lookupIntrinsicID(name); } -TVM_REGISTER_GLOBAL("codegen.llvm_lookup_intrinsic_id") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = static_cast(LookupLLVMIntrinsic(args[0])); - }); - TVM_REGISTER_GLOBAL("codegen.build_llvm") .set_body([](TVMArgs args, TVMRetValue* rv) { 
auto n = make_object(); @@ -361,9 +356,13 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm") *rv = runtime::Module(n); }); -TVM_REGISTER_GLOBAL("codegen.llvm_version_major") +TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") +.set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = static_cast(LookupLLVMIntrinsic(args[0])); + }); + +TVM_REGISTER_GLOBAL("target.llvm_version_major") .set_body([](TVMArgs args, TVMRetValue* rv) { - std::ostringstream os; int major = TVM_LLVM_VERSION / 10; *rv = major; }); diff --git a/src/target/target.cc b/src/target/target.cc index 245425a63921..05253a5a2bc9 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -144,7 +144,7 @@ Target CreateTarget(const std::string& target_name, return Target(t); } -TVM_REGISTER_GLOBAL("_TargetCreate") +TVM_REGISTER_GLOBAL("target.TargetCreate") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_name = args[0]; std::vector options; @@ -156,7 +156,7 @@ TVM_REGISTER_GLOBAL("_TargetCreate") *ret = CreateTarget(target_name, options); }); -TVM_REGISTER_GLOBAL("_TargetFromString") +TVM_REGISTER_GLOBAL("target.TargetFromString") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_str = args[0]; *ret = Target::Create(target_str); @@ -269,7 +269,7 @@ tvm::Target Target::Current(bool allow_not_defined) { return Target(); } -TVM_REGISTER_GLOBAL("_GetCurrentTarget") +TVM_REGISTER_GLOBAL("target.GetCurrentTarget") .set_body([](TVMArgs args, TVMRetValue* ret) { bool allow_not_defined = args[0]; *ret = Target::Current(allow_not_defined); @@ -284,10 +284,10 @@ class Target::Internal { } }; -TVM_REGISTER_GLOBAL("_EnterTargetScope") +TVM_REGISTER_GLOBAL("target.EnterTargetScope") .set_body_typed(Target::Internal::EnterScope); -TVM_REGISTER_GLOBAL("_ExitTargetScope") +TVM_REGISTER_GLOBAL("target.ExitTargetScope") .set_body_typed(Target::Internal::ExitScope); namespace target { diff --git a/src/tir/pass/lower_custom_datatypes.cc b/src/tir/pass/lower_custom_datatypes.cc index 66ea5743240a..b24fdf158f4a 100644 --- a/src/tir/pass/lower_custom_datatypes.cc +++ b/src/tir/pass/lower_custom_datatypes.cc @@ -95,19 +95,19 @@ class CustomDatatypesLowerer : public StmtExprMutator { return expr; } -#define DEFINE_MUTATE__(OP, NodeName) \ - inline PrimExpr VisitExpr_(const NodeName* op) final { \ - auto type_code = op->dtype.code(); \ +#define DEFINE_MUTATE__(OP, NodeName) \ + inline PrimExpr VisitExpr_(const NodeName* op) final { \ + auto type_code = op->dtype.code(); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ - PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ - op = expr.as(); \ - if (toBeLowered) { \ - auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ - CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ - << static_cast(type_code) << " not found"; \ - return (*lower)(expr); \ - } \ - return expr; \ + PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ + op = expr.as(); \ + if (toBeLowered) { \ + auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ + CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ + << static_cast(type_code) << " not found"; \ + return (*lower)(expr); \ + } \ + return expr; \ } DEFINE_MUTATE__(Add, AddNode); diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 02626a468aa2..c2c808fa7895 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -103,11 +103,11 @@ TEST(BuildModule, Heterogeneous) { return copy[i] - C[i]; }, 
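With the registrations above, the packed functions are namespaced by the component that owns them (runtime., target.) instead of carrying a leading underscore. A hedged sketch of reaching them directly from Python; it assumes an LLVM-enabled build with this refactor applied:

    import tvm

    f = tvm.get_global_func("target.llvm_version_major", allow_missing=True)
    if f is not None:
        print("LLVM major version:", f())

    g = tvm.get_global_func("runtime.GetDeviceAttr", allow_missing=True)
    if g is not None:
        # device_type=1 (CPU), device_id=0, attr_id=0 (kExist); this is the
        # same call the TVMContext.exist property makes internally.
        print("cpu(0) exists:", bool(g(1, 0, 0)))
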
"elemwise_sub"); - const runtime::PackedFunc* enter_target_scope_func = runtime::Registry::Get("_EnterTargetScope"); - (*enter_target_scope_func)(target_cuda); + With cuda_scope(target_cuda); auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); - (*enter_target_scope_func)(target_llvm); + + With llvm_scope(target_llvm); auto s2 = create_schedule({elemwise_sub->op}); auto config = BuildConfig::Create(); diff --git a/tests/python/integration/test_dot.py b/tests/python/integration/test_dot.py index db5214b91d1f..f95787dd94a4 100644 --- a/tests/python/integration/test_dot.py +++ b/tests/python/integration/test_dot.py @@ -55,7 +55,7 @@ def verify(target): if not tvm.runtime.enabled(target): print("Target %s is not enabled" % target) return - f = tvm.codegen.build_module(fapi, target) + f = tvm.target.codegen.build_module(fapi, target) # verify ctx = tvm.cpu(0) a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index ea729618097e..5876a7052a2d 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1115,7 +1115,7 @@ def _has_fast_int8_instructions(asm, target): # compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] - llvm_version = tvm.codegen.llvm_version_major() + llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: if llvm_version >= 8: dtypes = ('uint8', 'int8', 'int32') @@ -1208,7 +1208,7 @@ def test_depthwise_conv2d_int8(): parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))} targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] - llvm_version = tvm.codegen.llvm_version_major() + llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: if llvm_version >= 8: with relay.build_config(opt_level=3): diff --git a/tests/python/unittest/test_autotvm_common.py b/tests/python/unittest/test_autotvm_common.py index 4f8758e2eaf7..7043e473ec4d 100644 --- a/tests/python/unittest/test_autotvm_common.py +++ b/tests/python/unittest/test_autotvm_common.py @@ -50,7 +50,7 @@ def matmul(N, L, M, dtype): @autotvm.template def bad_matmul(N, L, M, dtype): - if 'bad_device' in tvm.target.current_target().keys: + if 'bad_device' in tvm.target.Target.current().keys: A = tvm.placeholder((N, L), name='A', dtype=dtype) B = tvm.placeholder((L, M), name='B', dtype=dtype) diff --git a/tests/python/unittest/test_codegen_c_host.py b/tests/python/unittest/test_codegen_c_host.py index 271237b51503..a126c07c8ac1 100644 --- a/tests/python/unittest/test_codegen_c_host.py +++ b/tests/python/unittest/test_codegen_c_host.py @@ -75,7 +75,7 @@ def check_c(): f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline") fsplits = [x for x in tvm.ir_pass.SplitHostDevice(f1)] fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) - mhost = tvm.codegen.build_module(fsplits[0], "c") + mhost = tvm.target.codegen.build_module(fsplits[0], "c") temp = util.tempdir() path_dso = temp.relpath("temp.so") mhost.export_library(path_dso) diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index fe416e6312d9..63ee03028e7e 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -84,8 +84,8 @@ def check_target(device, host="stackvm"): return if not tvm.runtime.enabled(host): return - mhost = 
tvm.codegen.build_module(fsplits[0], host) - mdev = tvm.codegen.build_module(fsplits[1:], device) + mhost = tvm.target.codegen.build_module(fsplits[0], host) + mdev = tvm.target.codegen.build_module(fsplits[1:], device) mhost.import_module(mdev) code = mdev.get_source() f = mhost.entry_func @@ -110,8 +110,8 @@ def check_module_save(device, host="stackvm"): fmt = "hsaco" else: fmt = device - mhost = tvm.codegen.build_module(fsplits[0], host) - mdev = tvm.codegen.build_module(fsplits[1:], device) + mhost = tvm.target.codegen.build_module(fsplits[0], host) + mdev = tvm.target.codegen.build_module(fsplits[1:], device) temp = util.tempdir() mpath = temp.relpath("test.%s" % fmt) mdev.save(mpath) diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index a37bc2a736e3..c60f3816722c 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -570,9 +570,9 @@ def test_dwarf_debug_information(): def check_llvm_object(): if not tvm.runtime.enabled("llvm"): return - if tvm.codegen.llvm_version_major() < 5: + if tvm.target.codegen.llvm_version_major() < 5: return - if tvm.codegen.llvm_version_major() > 6: + if tvm.target.codegen.llvm_version_major() > 6: return # build two functions f2 = tvm.lower(s, [A, B, C], name="fadd1") @@ -607,9 +607,9 @@ def check_llvm_object(): def check_llvm_ir(): if not tvm.runtime.enabled("llvm"): return - if tvm.codegen.llvm_version_major() < 5: + if tvm.target.codegen.llvm_version_major() < 5: return - if tvm.codegen.llvm_version_major() > 6: + if tvm.target.codegen.llvm_version_major() > 6: return # build two functions f2 = tvm.lower(s, [A, B, C], name="fadd1") diff --git a/tests/python/unittest/test_codegen_static_init.py b/tests/python/unittest/test_codegen_static_init.py index 80c4fa4df0e8..3bfe01319a3a 100644 --- a/tests/python/unittest/test_codegen_static_init.py +++ b/tests/python/unittest/test_codegen_static_init.py @@ -33,7 +33,7 @@ def test_static_callback(): stmt = ib.get() fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) - f = tvm.codegen.build_module(fapi, "llvm") + f = tvm.target.codegen.build_module(fapi, "llvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) f(a) @@ -57,7 +57,7 @@ def test_cb(sh, A): stmt = ib.get() fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) - f = tvm.codegen.build_module(fapi, "llvm") + f = tvm.target.codegen.build_module(fapi, "llvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) diff --git a/tests/python/unittest/test_codegen_vm_basic.py b/tests/python/unittest/test_codegen_vm_basic.py index 60a948db68bb..d477983b0979 100644 --- a/tests/python/unittest/test_codegen_vm_basic.py +++ b/tests/python/unittest/test_codegen_vm_basic.py @@ -21,7 +21,7 @@ def run_jit(fapi, check): for target in ["llvm", "stackvm"]: if not tvm.runtime.enabled(target): continue - f = tvm.codegen.build_module(fapi, target) + f = tvm.target.codegen.build_module(fapi, target) s = f.get_source() check(f) diff --git a/tests/python/unittest/test_codegen_x86.py b/tests/python/unittest/test_codegen_x86.py index 06591a3cbf76..e17c6cf8cbcc 100644 --- a/tests/python/unittest/test_codegen_x86.py +++ b/tests/python/unittest/test_codegen_x86.py @@ -19,9 +19,9 @@ def test_fp16_to_fp32(): - if tvm.codegen.llvm_version_major() < 6: + if tvm.target.codegen.llvm_version_major() < 6: print("Skipping due to LLVM version being {} < 6".format( - tvm.codegen.llvm_version_major())) + 
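The tests in these hunks keep their structure and only switch to the relocated tvm.target.codegen module. A hedged, self-contained sketch of the same idiom; the tensor shapes and the explicit LowerTVMBuiltin step are illustrative assumptions, not part of this patch:

    import numpy as np
    import tvm

    n = 16
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    if tvm.runtime.enabled("llvm"):
        fadd = tvm.lower(s, [A, B], name="fadd")
        fadd = tvm.ir_pass.LowerTVMBuiltin(fadd)
        m = tvm.target.codegen.build_module(fadd, "llvm")
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype))
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype))
        m(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)
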
tvm.target.codegen.llvm_version_major())) return def fp16_to_fp32(target, width, match=None, not_match=None): diff --git a/tests/python/unittest/test_custom_datatypes_mybfloat16.py b/tests/python/unittest/test_custom_datatypes_mybfloat16.py index 79c02efa6cc7..00f9b3329835 100644 --- a/tests/python/unittest/test_custom_datatypes_mybfloat16.py +++ b/tests/python/unittest/test_custom_datatypes_mybfloat16.py @@ -29,19 +29,19 @@ def setup_module(): # In this case, we have built the test functions used below right into TVM. # CDLL("libmybfloat16.so", RTLD_GLOBAL) - tvm.datatype.register("bfloat", 129) + tvm.target.datatype.register("bfloat", 129) - tvm.datatype.register_op( - tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast", + tvm.target.datatype.register_op( + tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast", "llvm", "bfloat", "float") - tvm.datatype.register_op( - tvm.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast", + tvm.target.datatype.register_op( + tvm.target.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast", "llvm", "float", "bfloat") - tvm.datatype.register_op( - tvm.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm", + tvm.target.datatype.register_op( + tvm.target.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm", "bfloat") - tvm.datatype.register_op( - tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm", + tvm.target.datatype.register_op( + tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm", "llvm", "bfloat") def lower_datatypes_and_build(schedule, args): diff --git a/tests/python/unittest/test_lang_target.py b/tests/python/unittest/test_lang_target.py index 85417d462c33..6da99f827047 100644 --- a/tests/python/unittest/test_lang_target.py +++ b/tests/python/unittest/test_lang_target.py @@ -50,7 +50,7 @@ def test_target_dispatch(): with tvm.target.create("metal"): assert mygeneric(1) == 3 - assert tvm.target.current_target() is None + assert tvm.target.Target.current() is None def test_target_string_parse(): diff --git a/tests/python/unittest/test_runtime_extension.py b/tests/python/unittest/test_runtime_extension.py index 38a7b43761b5..5207b0956941 100644 --- a/tests/python/unittest/test_runtime_extension.py +++ b/tests/python/unittest/test_runtime_extension.py @@ -39,7 +39,7 @@ def test_dltensor_compatible(): stmt = ib.get() fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) - f = tvm.codegen.build_module(fapi, "stackvm") + f = tvm.target.codegen.build_module(fapi, "stackvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) aview = MyTensorView(a) f(aview) diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index e47db94c4353..b1a784bc48b6 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -57,7 +57,7 @@ def save_object(names): i + 1)) fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) - m = tvm.codegen.build_module(fapi, "llvm") + m = tvm.target.codegen.build_module(fapi, "llvm") for name in names: m.save(name) diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index b0d4d1361ccc..f0d650adeac1 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -588,7 +588,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): idxd = tvm.indexdiv if groups == 1: - 
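The bfloat16 test above now goes through tvm.target.datatype. A hedged sketch of the registration flow with a hypothetical type name, code, and lowering symbol; the wrapper function would have to be linked into the process before anything is built:

    import tvm

    # Any code above the built-in range works; the test above uses 129.
    tvm.target.datatype.register("myfloat", 131)
    tvm.target.datatype.register_op(
        tvm.target.datatype.create_lower_func("FloatToMyFloat_wrapper"),
        "Cast", "llvm", "myfloat", "float")

    # Once registered, the dtype can appear in ordinary tensor expressions.
    A = tvm.placeholder((8,), dtype="float32", name="A")
    C = tvm.compute((8,), lambda i: A[i].astype("custom[myfloat]16"), name="C")
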
target = tvm.target.current_target() + target = tvm.target.Target.current() dispatch_ctx = autotvm.DispatchContext.current cfg = dispatch_ctx.query(target, workload) @@ -693,12 +693,12 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): else: raise RuntimeError("Unsupported template_key '%s'" % cfg.template_key) else: - target = tvm.target.current_target() + target = tvm.target.Target.current() dispatch_ctx = autotvm.DispatchContext.current cfg = dispatch_ctx.query(target, workload) if cfg.is_fallback: # if is fallback, clear query cache and return None - autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload) + autotvm.task.clear_fallback_cache(tvm.target.Target.current(), workload) if layout == 'NHWC' and kernel_layout == 'HWOI': new_attrs['data_layout'] = 'NCHW' new_attrs['kernel_layout'] = 'OIHW' diff --git a/topi/python/topi/bifrost/conv2d.py b/topi/python/topi/bifrost/conv2d.py index 3ee231eea419..2ae65800e925 100644 --- a/topi/python/topi/bifrost/conv2d.py +++ b/topi/python/topi/bifrost/conv2d.py @@ -156,7 +156,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): # this part to make tuning records correct s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region') else: - max_threads = tvm.target.current_target(allow_none=False).max_num_threads + max_threads = tvm.target.Target.current(allow_none=False).max_num_threads co, ci, kh, kw, vc = s[kernel_vec].op.axis fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) fused, vec = s[kernel_vec].split(fused, VC) diff --git a/topi/python/topi/bifrost/transforms.py b/topi/python/topi/bifrost/transforms.py index ea3e51082657..d7fc292f0ade 100644 --- a/topi/python/topi/bifrost/transforms.py +++ b/topi/python/topi/bifrost/transforms.py @@ -24,7 +24,7 @@ def fuse_and_bind(s, tensor, axis=None, num_thread=None): """Fuse all the axis and bind to GPU threads""" axis = axis or s[tensor].op.axis fused = s[tensor].fuse(*axis) - max_threads = tvm.target.current_target(allow_none=False).max_num_threads + max_threads = tvm.target.Target.current(allow_none=False).max_num_threads bx, tx = s[tensor].split(fused, num_thread or max_threads) s[tensor].bind(bx, tvm.thread_axis("blockIdx.x")) s[tensor].bind(tx, tvm.thread_axis("threadIdx.x")) diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py index 2d1b93ec0382..24fc2a17aa18 100644 --- a/topi/python/topi/cuda/batch_matmul.py +++ b/topi/python/topi/cuda/batch_matmul.py @@ -41,7 +41,7 @@ def batch_matmul_cuda(x, y): output : tvm.Tensor 3-D with shape [batch, M, N] """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name == "cuda" and "cublas" in target.libs: return cublas.batch_matmul(x, y, False, True) return batch_matmul_default(x, y) @@ -61,7 +61,7 @@ def schedule_batch_matmul(outs): s: Schedule The computation schedule for the op. 
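The TOPI compute and schedule functions in this and the following hunks all migrate from tvm.target.current_target() to the equivalent tvm.target.Target.current() classmethod. A hedged sketch of the library-dispatch idiom they share, restated outside the diff:

    import tvm
    from tvm.contrib import cublas

    def batch_matmul_sketch(x, y):
        target = tvm.target.Target.current()
        if target.target_name == "cuda" and "cublas" in target.libs:
            # Offload when the target was created with e.g.
            # tvm.target.create("cuda -libs=cublas").
            return cublas.batch_matmul(x, y, False, True)
        # Otherwise fall back to the generic TOPI compute (elided here).
        ...
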
""" - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name == "cuda" and "cublas" in target.libs: return generic.schedule_extern(outs) diff --git a/topi/python/topi/cuda/conv1d.py b/topi/python/topi/cuda/conv1d.py index 201921564cbf..43754a31df48 100644 --- a/topi/python/topi/cuda/conv1d.py +++ b/topi/python/topi/cuda/conv1d.py @@ -115,7 +115,7 @@ def _callback(op): cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: @@ -230,7 +230,7 @@ def _callback(op): cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: diff --git a/topi/python/topi/cuda/conv1d_transpose_ncw.py b/topi/python/topi/cuda/conv1d_transpose_ncw.py index be7824e71e81..4cedbd529f02 100644 --- a/topi/python/topi/cuda/conv1d_transpose_ncw.py +++ b/topi/python/topi/cuda/conv1d_transpose_ncw.py @@ -116,7 +116,7 @@ def _callback(op): cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index d831ba494d9f..f26069cfc3f0 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -69,7 +69,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cudnn" in target.libs: if layout == 'NCHW': @@ -148,7 +148,7 @@ def schedule_conv2d_nchw_cuda(cfg, outs): s: Schedule The computation schedule for conv2d. """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if 'cudnn' in target.libs: return generic.schedule_extern(outs) @@ -186,7 +186,7 @@ def schedule_conv2d_nhwc_cuda(cfg, outs): s: Schedule The computation schedule for conv2d. 
""" - target = tvm.target.current_target() + target = tvm.target.Target.current() if 'cudnn' in target.libs: return generic.schedule_extern(outs) diff --git a/topi/python/topi/cuda/conv2d_direct.py b/topi/python/topi/cuda/conv2d_direct.py index d64712550855..b7df88579f49 100644 --- a/topi/python/topi/cuda/conv2d_direct.py +++ b/topi/python/topi/cuda/conv2d_direct.py @@ -34,7 +34,7 @@ def schedule_direct_cuda(cfg, s, conv): cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py index 26bc26169674..be9f31567bc9 100644 --- a/topi/python/topi/cuda/conv2d_transpose_nchw.py +++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py @@ -170,7 +170,7 @@ def _callback(op): cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: diff --git a/topi/python/topi/cuda/conv2d_winograd.py b/topi/python/topi/cuda/conv2d_winograd.py index dfa569a556ce..37307d62357d 100644 --- a/topi/python/topi/cuda/conv2d_winograd.py +++ b/topi/python/topi/cuda/conv2d_winograd.py @@ -194,7 +194,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed): cfg.define_split("tile_x", x, num_outputs=4) cfg.define_split("tile_rc", rc, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 128, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: @@ -325,7 +325,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F): Unlike other TOPI functions, this function operates on both graph level and operator level, so we have to pass 'F' to make it support our two versions of graph IR, Relay. """ - if 'cudnn' in tvm.target.current_target().libs or 'miopen' in tvm.target.current_target().libs: + if 'cudnn' in tvm.target.Target.current().libs or 'miopen' in tvm.target.Target.current().libs: return None copy_inputs = list(inputs) @@ -349,7 +349,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F): CO, _, KH, KW = get_const_tuple(kernel.shape) dispatch_ctx = autotvm.DispatchContext.current - target = tvm.target.current_target() + target = tvm.target.Target.current() if groups == 1: # query config of this workload diff --git a/topi/python/topi/cuda/conv3d.py b/topi/python/topi/cuda/conv3d.py index 7d3c0b4afc1b..b46f284ef5b7 100644 --- a/topi/python/topi/cuda/conv3d.py +++ b/topi/python/topi/cuda/conv3d.py @@ -64,7 +64,7 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o output : tvm.Tensor 5-D with shape [batch, out_channel, out_depth, out_height, out_width] """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cudnn" in target.libs: if layout == 'NCDHW': @@ -126,7 +126,7 @@ def schedule_conv3d_ncdhw_cuda(cfg, outs): s: Schedule The computation schedule for conv2d. 
""" - target = tvm.target.current_target() + target = tvm.target.Target.current() if 'cudnn' in target.libs: return generic.schedule_extern(outs) @@ -160,7 +160,7 @@ def schedule_conv3d_ndhwc_cuda(cfg, outs): s: Schedule The computation schedule for conv2d. """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if 'cudnn' in target.libs: return generic.schedule_extern(outs) diff --git a/topi/python/topi/cuda/conv3d_direct.py b/topi/python/topi/cuda/conv3d_direct.py index e38dbcbfa002..ad48deb27539 100644 --- a/topi/python/topi/cuda/conv3d_direct.py +++ b/topi/python/topi/cuda/conv3d_direct.py @@ -36,7 +36,7 @@ def schedule_direct_3d_cuda(cfg, s, conv): cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: diff --git a/topi/python/topi/cuda/deformable_conv2d.py b/topi/python/topi/cuda/deformable_conv2d.py index a0e1cb8f5fc6..33a8c9adc1ca 100644 --- a/topi/python/topi/cuda/deformable_conv2d.py +++ b/topi/python/topi/cuda/deformable_conv2d.py @@ -67,7 +67,7 @@ def schedule_direct_cuda(cfg, s, conv): cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: diff --git a/topi/python/topi/cuda/dense.py b/topi/python/topi/cuda/dense.py index f17feb0b7a98..1a1af703c55c 100644 --- a/topi/python/topi/cuda/dense.py +++ b/topi/python/topi/cuda/dense.py @@ -60,7 +60,7 @@ def dense_cuda(cfg, data, weight, bias=None, out_dtype=None): out_dtype = data.dtype batch, in_dim = data.shape out_dim, _ = weight.shape - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cublas" in target.libs: matmul = cublas.matmul(data, weight, False, True, out_dtype) if bias is not None: @@ -87,7 +87,7 @@ def schedule_dense(cfg, outs): The computation schedule for dense. 
""" # pylint: disable=unused-argument - target = tvm.target.current_target() + target = tvm.target.Target.current() outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs if target.target_name == "cuda" and "cublas" in target.libs: @@ -259,7 +259,7 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None): batch, in_dim = get_const_tuple(data.shape) out_dim, _ = get_const_tuple(weight.shape) - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cublas" in target.libs: matmul = cublas.matmul(data, weight, False, True, out_dtype) if bias is not None: @@ -290,7 +290,7 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None): def schedule_dense_int8(cfg, outs): """Dense schedule for int8 on CUDA""" s = tvm.create_schedule([x.op for x in outs]) - target = tvm.target.current_target() + target = tvm.target.Target.current() outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs if "cublas" in target.libs: diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py index 6dbfbfe39cae..05e1117ac2ce 100644 --- a/topi/python/topi/cuda/depthwise_conv2d.py +++ b/topi/python/topi/cuda/depthwise_conv2d.py @@ -57,7 +57,7 @@ def _callback(op): cfg.define_split("tile_x", x, num_outputs=4) cfg.define_knob("auto_unroll_max_step", [0, 256, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: @@ -166,7 +166,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): # num_thread here could be 728, it is larger than cuda.max_num_threads num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value - target = tvm.target.current_target() + target = tvm.target.Target.current() if target and (target.target_name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads xoc, xic = s[Output].split(c, factor=num_thread) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index f4bb73470651..54e8427daf79 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -340,7 +340,7 @@ def schedule_group_conv2d_nchw_direct(cfg, s, conv): cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: diff --git a/topi/python/topi/cuda/injective.py b/topi/python/topi/cuda/injective.py index 0a131148be68..b77a97924716 100644 --- a/topi/python/topi/cuda/injective.py +++ b/topi/python/topi/cuda/injective.py @@ -37,7 +37,7 @@ def schedule_injective_from_existing(sch, out): The updated schedule. 
""" fused = sch[out].fuse(*sch[out].op.axis) - num_thread = tvm.target.current_target(allow_none=False).max_num_threads + num_thread = tvm.target.Target.current(allow_none=False).max_num_threads max_block = 256 try: diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 8ddb44efc738..38f87a9523c8 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -71,7 +71,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index id_index = tvm.make.node("IntImm", dtype="int32", value=id_index) score_index = tvm.make.node("IntImm", dtype="int32", value=score_index) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = batch_size * num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") @@ -120,7 +120,7 @@ def get_valid_counts_upsweep(data, idx_in, idx, partial): idx_in = ib.buffer_ptr(idx_in) idx = ib.buffer_ptr(idx) partial = ib.buffer_ptr(partial) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) elem_per_thread = num_anchors // max_threads + 1 nthread_tx = max_threads nthread_bx = batch_size @@ -176,7 +176,7 @@ def get_valid_counts_scan(data, partial_in, partial): ib = tvm.ir_builder.create() partial_in = ib.buffer_ptr(partial_in) partial = ib.buffer_ptr(partial) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) elem_per_thread = num_anchors // max_threads + 1 nthread_tx = max_threads nthread_bx = batch_size @@ -234,7 +234,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx): idx_in = ib.buffer_ptr(idx_in) idx = ib.buffer_ptr(idx) partial = ib.buffer_ptr(partial) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) elem_per_thread = num_anchors // max_threads + 1 nthread_tx = max_threads nthread_bx = batch_size * num_anchors // max_threads + 1 @@ -297,7 +297,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): valid_count = ib.buffer_ptr(valid_count) out = ib.buffer_ptr(out) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") @@ -356,7 +356,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1): """ batch_size = data.shape[0] num_anchors = data.shape[1] - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) elem_per_thread = num_anchors // max_threads + 1 new_range = num_anchors // elem_per_thread + 1 temp_flag_buf = api.decl_buffer( @@ -482,7 +482,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") max_threads = int( - tvm.target.current_target(allow_none=False).max_num_threads) + tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") @@ -594,7 
+594,7 @@ def invalid_to_bottom_pre(data, flag, idx): idx = ib.buffer_ptr(idx) max_threads = int(math.sqrt( - tvm.target.current_target(allow_none=False).max_num_threads)) + tvm.target.Target.current(allow_none=False).max_num_threads)) nthread_tx = max_threads nthread_bx = num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") @@ -654,7 +654,7 @@ def invalid_to_bottom_ir(data, flag, idx, out): out = ib.buffer_ptr(out) max_threads = int(math.sqrt( - tvm.target.current_target(allow_none=False).max_num_threads)) + tvm.target.Target.current(allow_none=False).max_num_threads)) nthread_tx = max_threads nthread_bx = num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") diff --git a/topi/python/topi/cuda/nn.py b/topi/python/topi/cuda/nn.py index a5c310eb1a45..327afa87edb5 100644 --- a/topi/python/topi/cuda/nn.py +++ b/topi/python/topi/cuda/nn.py @@ -37,6 +37,6 @@ def schedule_lrn(outs): sch: Schedule The computation schedule for the op. """ - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) cpp_target = cpp.TEST_create_target(target.target_name) return cpp.cuda.schedule_lrn(cpp_target, outs) diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py index f11085afb3cd..2bf1e6bb9ef0 100644 --- a/topi/python/topi/cuda/pooling.py +++ b/topi/python/topi/cuda/pooling.py @@ -112,7 +112,7 @@ def schedule_pool(outs, layout): def _schedule(PaddedInput, Pool): if isinstance(PaddedInput.op, tvm.tensor.ComputeOp): s[PaddedInput].compute_inline() - num_thread = tvm.target.current_target(allow_none=False).max_num_threads + num_thread = tvm.target.Target.current(allow_none=False).max_num_threads if Pool.op in s.outputs: Out = Pool OL = s.cache_write(Pool, "local") @@ -177,7 +177,7 @@ def _schedule_pool_grad(op): else: out = outs[0].op.output(0) fused = s[out].fuse(*s[out].op.axis) - num_thread = tvm.target.current_target(allow_none=False).max_num_threads + num_thread = tvm.target.Target.current(allow_none=False).max_num_threads bx, tx = s[out].split(fused, factor=num_thread) s[out].bind(bx, tvm.thread_axis("blockIdx.x")) s[out].bind(tx, tvm.thread_axis("threadIdx.x")) diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py index 7567c651772b..4344226d787e 100644 --- a/topi/python/topi/cuda/rcnn/proposal.py +++ b/topi/python/topi/cuda/rcnn/proposal.py @@ -64,7 +64,7 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r """ batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape) num_anchors //= 2 - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = (batch * height * width) // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") @@ -152,7 +152,7 @@ def argsort_ir(data_buf, out_index_buf): The result IR statement. 
""" batch, num_bbox = get_const_tuple(data_buf.shape) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() p_data = ib.buffer_ptr(data_buf) index_out = ib.buffer_ptr(out_index_buf) @@ -225,7 +225,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): return i / u batch, num_bbox = get_const_tuple(out_buf.shape) - max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads)) + max_threads = int(math.sqrt(tvm.target.Target.current(allow_none=False).max_num_threads)) tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") ib = tvm.ir_builder.create() diff --git a/topi/python/topi/cuda/reduction.py b/topi/python/topi/cuda/reduction.py index 2968ab75e040..69c685cb50b4 100644 --- a/topi/python/topi/cuda/reduction.py +++ b/topi/python/topi/cuda/reduction.py @@ -35,7 +35,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): if len(sch[data_out].op.axis) > 0: all_reduce = False num_thread = 32 - target = tvm.target.current_target() + target = tvm.target.Target.current() if target and target.target_name == "opencl": # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py # don't know why @@ -45,7 +45,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y") else: all_reduce = True - num_thread = tvm.target.current_target(allow_none=False).max_num_threads + num_thread = tvm.target.Target.current(allow_none=False).max_num_threads thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") # Fuse and refactor the reduce axis diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index 0e7a23eb14ab..b32cce75362f 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -87,7 +87,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): axis_mul_before *= value elif i > axis: axis_mul_after *= value - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() data = ib.buffer_ptr(data) values_out = ib.buffer_ptr(values_out) @@ -186,7 +186,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): axis_mul_before *= value elif i > axis: axis_mul_after *= value - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() data = ib.buffer_ptr(data) valid_count = ib.buffer_ptr(valid_count) diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py index e1af4365520e..10ba7a1051ea 100644 --- a/topi/python/topi/cuda/ssd/multibox.py +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -60,7 +60,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): The result IR statement. 
""" max_threads = int(math.sqrt( - tvm.target.current_target(allow_none=False).max_num_threads)) + tvm.target.Target.current(allow_none=False).max_num_threads)) tx = tvm.thread_axis("threadIdx.x") ty = tvm.thread_axis("threadIdx.y") bx = tvm.thread_axis("blockIdx.x") @@ -196,7 +196,7 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = (batch_size * num_anchors) // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") @@ -307,7 +307,7 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, score = ib.buffer_ptr(temp_score) out_loc = ib.buffer_ptr(out) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = (batch_size * num_anchors) // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index 2df273ff50e3..d456aadf4f5e 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -53,7 +53,7 @@ def schedule_reorg(outs): s: Schedule The computation schedule for reorg. """ - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) cpp_target = cpp.TEST_create_target(target.target_name) return cpp.cuda.schedule_injective(cpp_target, outs) diff --git a/topi/python/topi/generic/extern.py b/topi/python/topi/generic/extern.py index a0601147d5e8..e895385e8b66 100644 --- a/topi/python/topi/generic/extern.py +++ b/topi/python/topi/generic/extern.py @@ -36,5 +36,5 @@ def schedule_extern(outs): sch: Schedule The computation schedule for the op. """ - target = tvm.target.current_target() + target = tvm.target.Target.current() return cpp.generic.schedule_extern(target, outs) diff --git a/topi/python/topi/generic/injective.py b/topi/python/topi/generic/injective.py index 178363dc0d4f..2aff96f9636c 100644 --- a/topi/python/topi/generic/injective.py +++ b/topi/python/topi/generic/injective.py @@ -54,7 +54,7 @@ def schedule_injective(outs): sch: Schedule The computation schedule for the op. """ - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) if target.target_name != "llvm": raise RuntimeError("schedule_injective not registered for '%s'" % target) outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index be9e54e97a1e..883182941202 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -22,7 +22,7 @@ def _default_schedule(outs, auto_inline): """Default schedule for llvm.""" - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs if target.target_name not in ("llvm", "c"): raise RuntimeError("schedule not registered for '%s'" % target) @@ -645,7 +645,7 @@ def schedule_lrn(outs): sch: Schedule The computation schedule for the op. 
""" - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) cpp_target = cpp.TEST_create_target(target.target_name) return cpp.generic.default_schedule(cpp_target, outs, False) @@ -686,6 +686,6 @@ def schedule_sparse_transpose(outs): @tvm.target.generic_func def schedule_batch_matmul(outs): - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) cpp_target = cpp.TEST_create_target(target.target_name) return cpp.generic.default_schedule(cpp_target, outs, False) diff --git a/topi/python/topi/generic/vision.py b/topi/python/topi/generic/vision.py index a1e096a85880..85d9153e6424 100644 --- a/topi/python/topi/generic/vision.py +++ b/topi/python/topi/generic/vision.py @@ -22,7 +22,7 @@ def _default_schedule(outs, auto_inline): """Default schedule for llvm.""" - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs if target.target_name != "llvm": raise RuntimeError("schedule not registered for '%s'" % target) @@ -48,7 +48,7 @@ def schedule_reorg(outs): s: Schedule The computation schedule for the op. """ - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) cpp_target = cpp.TEST_create_target(target.target_name) return cpp.generic.default_schedule(cpp_target, outs, False) diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py index f02eb497f519..65ea590905f9 100644 --- a/topi/python/topi/intel_graphics/conv2d.py +++ b/topi/python/topi/intel_graphics/conv2d.py @@ -221,7 +221,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): return None dispatch_ctx = autotvm.task.DispatchContext.current - target = tvm.target.current_target() + target = tvm.target.Target.current() # query schedule and fallback if necessary workload = autotvm.task.args_to_workload( diff --git a/topi/python/topi/intel_graphics/depthwise_conv2d.py b/topi/python/topi/intel_graphics/depthwise_conv2d.py index c747c539d7fe..97b7376933de 100644 --- a/topi/python/topi/intel_graphics/depthwise_conv2d.py +++ b/topi/python/topi/intel_graphics/depthwise_conv2d.py @@ -59,7 +59,7 @@ def _callback(op): cfg.define_split("tile_x", x, num_outputs=4) cfg.define_knob("auto_unroll_max_step", [0, 256, 1500]) - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: @@ -167,7 +167,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): # num_thread here could be 728, it is larger than cuda.max_num_threads num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value - target = tvm.target.current_target() + target = tvm.target.Target.current() if target and (target.target_name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads xoc, xic = s[Output].split(c, factor=num_thread) diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index ea4661f7602e..35a86e991c23 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -153,7 +153,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): # this part to make tuning records correct s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region') else: - max_threads = tvm.target.current_target(allow_none=False).max_num_threads + max_threads = tvm.target.Target.current(allow_none=False).max_num_threads 
co, ci, kh, kw, vc = s[kernel_vec].op.axis fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) fused, vec = s[kernel_vec].split(fused, VC) diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index abdb5f22e369..52f4b12a1d2d 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -465,7 +465,7 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \ get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn diff --git a/topi/python/topi/rocm/conv2d.py b/topi/python/topi/rocm/conv2d.py index 0a41838aa50e..be29c6f6b0cc 100644 --- a/topi/python/topi/rocm/conv2d.py +++ b/topi/python/topi/rocm/conv2d.py @@ -57,7 +57,7 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou 4-D with shape [batch, out_channel, out_height, out_width] """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if "miopen" in target.libs: assert layout == 'NCHW', "Only NCHW layout is supported." CO, CI, KH, KW = get_const_tuple(kernel.shape) @@ -106,7 +106,7 @@ def schedule_conv2d_nchw_rocm(cfg, outs): s: Schedule The computation schedule for conv2d. """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if target and "miopen" in target.libs: return generic.schedule_extern(outs) diff --git a/topi/python/topi/rocm/dense.py b/topi/python/topi/rocm/dense.py index 6fca7cd79656..f2adeaabef61 100644 --- a/topi/python/topi/rocm/dense.py +++ b/topi/python/topi/rocm/dense.py @@ -56,7 +56,7 @@ def dense_rocm(cfg, data, weight, bias=None, out_dtype=None): out_dtype = data.dtype batch, in_dim = data.shape out_dim, _ = weight.shape - target = tvm.target.current_target() + target = tvm.target.Target.current() if "rocblas" in target.libs: assert out_dtype == data.dtype, "Mixed precision not supported." matmul = rocblas.matmul(data, weight, False, True) @@ -83,7 +83,7 @@ def schedule_dense(cfg, outs): s: Schedule The computation schedule for dense. """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if target.target_name == "rocm" and "rocblas" in target.libs: return generic.schedule_extern(outs) return topi.cuda.schedule_dense(cfg, outs) diff --git a/topi/python/topi/rocm/nn.py b/topi/python/topi/rocm/nn.py index bb6a8bf43557..8a9c8c393da6 100644 --- a/topi/python/topi/rocm/nn.py +++ b/topi/python/topi/rocm/nn.py @@ -23,6 +23,6 @@ @generic.schedule_lrn.register(["rocm", "gpu"]) def schedule_lrn(outs): - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) cpp_target = cpp.TEST_create_target(target.target_name) return cpp.rocm.schedule_lrn(cpp_target, outs) diff --git a/topi/python/topi/x86/batch_matmul.py b/topi/python/topi/x86/batch_matmul.py index 25b49d12400d..fef6c48d6bed 100644 --- a/topi/python/topi/x86/batch_matmul.py +++ b/topi/python/topi/x86/batch_matmul.py @@ -43,7 +43,7 @@ def _declaration_batch_matmul_nopack(cfg, x, y): output : tvm.Tensor 3-D with shape [batch, M, N] """ - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cblas" in target.libs: return cblas.batch_matmul(x, y, False, True) @@ -83,7 +83,7 @@ def schedule_batch_matmul(cfg, outs): sch: Schedule The computation schedule for the op. 
""" - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cblas" in target.libs: return generic.schedule_extern(outs) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 1ba4f68be6c4..95ce3376ac3a 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -74,7 +74,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): kh, kw, oc, _ = kshape elif pat.match(layout) is not None: n, ic_chunk, h, w, ic_bn = dshape - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape assert ic_chunk == k_ic_chunk assert ic_bn == k_ic_bn @@ -423,7 +423,7 @@ def traverse(op): data = data_pad.op.input_tensors[0] args = [s, cfg, data_vec, conv_out, outs[0]] - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) _, _, kh, kw, _, _, = get_const_tuple(kernel.shape) if kh == 1 and kw == 1: conv2d_avx_1x1._schedule_conv_NCHWc(*args) diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 9374387fb23a..8b0c13c2c0bb 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -75,7 +75,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): # Set workload. Config update. dispatch_ctx = autotvm.task.DispatchContext.current - target = tvm.target.current_target() + target = tvm.target.Target.current() if is_depthwise: workload = autotvm.task.args_to_workload( diff --git a/topi/python/topi/x86/conv2d_int8.py b/topi/python/topi/x86/conv2d_int8.py index 1701643844e1..20712d2f6f4f 100644 --- a/topi/python/topi/x86/conv2d_int8.py +++ b/topi/python/topi/x86/conv2d_int8.py @@ -64,11 +64,11 @@ def _is_int8_hw_support(data_dtype, kernel_dtype): is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8' # 2) Check LLVM support - llvm_version = tvm.codegen.llvm_version_major() + llvm_version = tvm.target.codegen.llvm_version_major() is_llvm_support = llvm_version >= 8 # 3) Check target - mcpu = tvm.target.current_target().mcpu + mcpu = tvm.target.Target.current().mcpu is_target_support = False if mcpu in ('skylake-avx512', 'cascadelake'): is_target_support = True @@ -89,7 +89,7 @@ def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, lay kh, kw, oc, _ = kshape elif pat.match(layout) is not None: n, ic_chunk, h, w, ic_bn = dshape - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape ic = ic_chunk * ic_bn assert ic == k_ic * k_ic_f * k_ic_s @@ -205,7 +205,7 @@ def traverse(op): data = data_pad.op.input_tensors[0] args = [s, cfg, data_vec, conv_out, outs[0]] - target = tvm.target.current_target(allow_none=False) + target = tvm.target.Target.current(allow_none=False) # int8 conv kernel is 7-dim _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape) if kh == 1 and kw == 1: diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py index dd1822f0fd73..c6c3d5e667ac 100644 --- a/topi/python/topi/x86/dense.py +++ b/topi/python/topi/x86/dense.py @@ -28,7 +28,7 @@ @autotvm.register_topi_compute(nn.dense, "cpu", "direct") def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None): - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cblas" in target.libs: C = cblas.matmul(data, weight, False, 
True) if bias is not None: @@ -119,7 +119,7 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None): @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct") def _schedule_dense(cfg, outs): - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cblas" in target.libs: return generic.schedule_extern(outs) @@ -136,7 +136,7 @@ def _callback(op): @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack") def _schedule_dense_pack(cfg, outs): - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cblas" in target.libs: return generic.schedule_extern(outs) @@ -151,7 +151,7 @@ def _callback(op): @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack") def _schedule_dense_nopack(cfg, outs): - target = tvm.target.current_target() + target = tvm.target.Target.current() if "cblas" in target.libs: return generic.schedule_extern(outs) diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py index a8ad251115d7..dc9e1456d2cd 100644 --- a/topi/python/topi/x86/tensor_intrin.py +++ b/topi/python/topi/x86/tensor_intrin.py @@ -17,11 +17,12 @@ """Core kernel of dot product of 4 Int8 operations""" #pylint: disable=invalid-name import tvm +import tvm.target.codegen def dot_16x1x16_uint8_int8_int32(): """Dispatch the most optimized intrin depending on the target""" - mcpu = tvm.target.current_target().mcpu + mcpu = tvm.target.Target.current().mcpu assert mcpu in ("skylake-avx512", "cascadelake"), \ "An old Intel machine that does not have fast Int8 support." @@ -254,7 +255,7 @@ def _instr(index): vec_b = ins[1].vload([0, 0], "int8x64") vnni_inst_name = 'llvm.x86.avx512.vpdpbusd.512' - llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(vnni_inst_name) + llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name) if llvm_id != 0: # VNNI is available for current LLVM version vec_bi32 = tvm.call_pure_intrin('int32x16', 'reinterpret', vec_b) diff --git a/topi/python/topi/x86/util.py b/topi/python/topi/x86/util.py index aff37aa36202..04931f577b51 100644 --- a/topi/python/topi/x86/util.py +++ b/topi/python/topi/x86/util.py @@ -19,7 +19,7 @@ import tvm def get_fp32_len(): - mcpu = tvm.target.current_target().mcpu + mcpu = tvm.target.Target.current().mcpu fp32_vec_len = 8 if mcpu in ('skylake-avx512', 'cascadelake'): fp32_vec_len = 16 diff --git a/vta/python/vta/top/op.py b/vta/python/vta/top/op.py index ae77f00fb8a9..bf6409cc9405 100644 --- a/vta/python/vta/top/op.py +++ b/vta/python/vta/top/op.py @@ -84,7 +84,7 @@ def compute_conv2d(attrs, inputs, output_type, target): groups, out_dtype)] # If it's not packed, run on ARM CPU - with tvm.target.arm_cpu(tvm.target.current_target().model): + with tvm.target.arm_cpu(tvm.target.Target.current().model): return _nn.compute_conv2d(attrs, inputs, output_type, target) # If VTA is not the target, default to _nn def @@ -105,8 +105,8 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_conv2d_nchw(outs) return topi.generic.schedule_group_conv2d_nchw(outs) # If it's not packed, run on ARM CPU - with tvm.target.arm_cpu(tvm.target.current_target().model): - return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target()) + with tvm.target.arm_cpu(tvm.target.Target.current().model): + return _nn.schedule_conv2d(attrs, outs, tvm.target.Target.current()) # If VTA is not the target, default to _nn def return _nn.schedule_conv2d(attrs, outs, target) @@ -128,7 +128,7 @@ def 
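The x86 tensor_intrin and util changes above read mcpu from the current target and look up the VNNI intrinsic through the relocated tvm.target.codegen. A hedged sketch of that capability check:

    import tvm
    import tvm.target.codegen

    def vnni_available():
        mcpu = tvm.target.Target.current().mcpu
        if mcpu not in ("skylake-avx512", "cascadelake"):
            return False
        # 0 means this LLVM build does not know the intrinsic.
        llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(
            "llvm.x86.avx512.vpdpbusd.512")
        return llvm_id != 0
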
compute_conv2d_transpose(attrs, inputs, output_type, target): return [topi.nn.conv2d_transpose_nchw( inputs[0], inputs[1], strides, padding, out_dtype)] # If it's not packed, run on ARM CPU - with tvm.target.arm_cpu(tvm.target.current_target().model): + with tvm.target.arm_cpu(tvm.target.Target.current().model): return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target) # If VTA is not the target, default to _nn def @@ -145,11 +145,11 @@ def schedule_conv2d_transpose(attrs, outputs, target): if is_packed_layout(layout): return topi.nn.schedule_conv2d_transpose_nchw(outputs) # If it's not packed, run on ARM CPU - with tvm.target.arm_cpu(tvm.target.current_target().model): - return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target()) + with tvm.target.arm_cpu(tvm.target.Target.current().model): + return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current()) # If VTA is not the target, default to _nn def - return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target()) + return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current()) @reg.register_compute("nn.dense", level=15) @@ -163,7 +163,7 @@ def compute_dense(attrs, inputs, out_type, target): target = tvm.target.create(target) return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)] # If it's not packed, run on ARM CPU - with tvm.target.arm_cpu(tvm.target.current_target().model): + with tvm.target.arm_cpu(tvm.target.Target.current().model): return _nn.compute_dense(attrs, inputs, out_type, target) # If VTA is not the target, default to _nn def @@ -179,8 +179,8 @@ def schedule_dense(attrs, outs, target): assert target.device_name == "vta" return topi.generic.schedule_dense(outs) # If it's not packed, run on ARM CPU - with tvm.target.arm_cpu(tvm.target.current_target().model): - return _nn.schedule_dense(attrs, outs, tvm.target.current_target()) + with tvm.target.arm_cpu(tvm.target.Target.current().model): + return _nn.schedule_dense(attrs, outs, tvm.target.Target.current()) # If VTA is not the target, default to _nn def return _nn.schedule_dense(attrs, outs, target) diff --git a/vta/scripts/tune_conv2d.py b/vta/scripts/tune_conv2d.py index 2780f26ca57c..265a6392b054 100644 --- a/vta/scripts/tune_conv2d.py +++ b/vta/scripts/tune_conv2d.py @@ -80,7 +80,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation): res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) res = topi.cast(res, env.out_dtype) - if tvm.target.current_target().device_name == 'vta': + if tvm.target.Target.current().device_name == 'vta': s = topi.generic.schedule_conv2d_nchw([res]) else: s = tvm.create_schedule([res.op]) diff --git a/vta/scripts/tune_conv2d_transpose.py b/vta/scripts/tune_conv2d_transpose.py index f779b76f8277..d6475abff667 100644 --- a/vta/scripts/tune_conv2d_transpose.py +++ b/vta/scripts/tune_conv2d_transpose.py @@ -68,7 +68,7 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding): res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) res = topi.cast(res, env.out_dtype) - if tvm.target.current_target().device_name == 'vta': + if tvm.target.Target.current().device_name == 'vta': s = topi.generic.schedule_conv2d_transpose_nchw([res]) else: s = tvm.create_schedule([res.op]) diff --git a/vta/scripts/tune_dense.py b/vta/scripts/tune_dense.py index 7813b00fc878..fa49be7f9b27 100644 --- a/vta/scripts/tune_dense.py +++ b/vta/scripts/tune_dense.py @@ -59,7 +59,7 @@ def dense(N, CI, CO): res = my_clip(res, 0, 127) res = topi.cast(res, "int8") - if 
tvm.target.current_target().device_name == 'vta': + if tvm.target.Target.current().device_name == 'vta': s = topi.generic.schedule_dense([res]) else: s = tvm.create_schedule([res.op]) diff --git a/vta/scripts/tune_group_conv2d.py b/vta/scripts/tune_group_conv2d.py index c578090e26aa..555154d708fc 100644 --- a/vta/scripts/tune_group_conv2d.py +++ b/vta/scripts/tune_group_conv2d.py @@ -80,7 +80,7 @@ def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, group): res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) res = topi.cast(res, env.out_dtype) - if tvm.target.current_target().device_name == 'vta': + if tvm.target.Target.current().device_name == 'vta': s = topi.generic.schedule_group_conv2d_nchw([res]) else: s = tvm.create_schedule([res.op]) diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py index 9d8ed8980bb0..b9edc30e5ba3 100644 --- a/vta/scripts/tune_resnet.py +++ b/vta/scripts/tune_resnet.py @@ -84,7 +84,7 @@ def _topi_nn_conv2d(*args, **kwargs): res = my_clip(res, 0, 127) res = topi.cast(res, "int8") - if tvm.target.current_target().device_name == 'vta': + if tvm.target.Target.current().device_name == 'vta': s = topi.generic.schedule_conv2d_nchw([res]) else: s = tvm.create_schedule([res.op]) @@ -102,7 +102,7 @@ def _topi_nn_dense(*args, **kwargs): res = my_clip(res, 0, 127) res = topi.cast(res, "int8") - if tvm.target.current_target().device_name == 'vta': + if tvm.target.Target.current().device_name == 'vta': s = topi.generic.schedule_dense([res]) else: s = tvm.create_schedule([res.op]) diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py index 3221c3b77b1f..94fba3db2989 100644 --- a/vta/tutorials/autotvm/tune_relay_vta.py +++ b/vta/tutorials/autotvm/tune_relay_vta.py @@ -321,7 +321,7 @@ def _topi_nn_conv2d(*args, **kwargs): res = my_clip(res, 0, 127) res = topi.cast(res, "int8") - if tvm.target.current_target().device_name == 'vta': + if tvm.target.Target.current().device_name == 'vta': s = topi.generic.schedule_conv2d_nchw([res]) else: s = tvm.create_schedule([res.op])
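All of the VTA tuning scripts above share the same dispatch on the device name of the target in scope. A hedged sketch of that idiom, with res standing for an already-declared output tensor:

    import topi
    import tvm

    def schedule_result(res):
        if tvm.target.Target.current().device_name == "vta":
            return topi.generic.schedule_dense([res])
        return tvm.create_schedule([res.op])
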