docs/api/python/index.rst (2 changes: 1 addition & 1 deletion)

@@ -26,10 +26,10 @@ Python API
    ndarray
    error
    ir
-   target
    intrin
    tensor
    schedule
+   target
    build
    function
    autotvm
docs/api/python/target.rst (1 change: 1 addition & 0 deletions)

@@ -19,3 +19,4 @@ tvm.target
 ----------
 .. automodule:: tvm.target
     :members:
+    :imported-members:
python/tvm/__init__.py (2 changes: 0 additions & 2 deletions)

@@ -46,7 +46,6 @@
 from . import stmt
 from . import make
 from . import ir_pass
-from . import codegen
 from . import schedule
 
 from . import ir_builder
@@ -55,7 +54,6 @@
 from . import hybrid
 from . import testing
 from . import error
-from . import datatype
 
 
 from .api import *
python/tvm/_ffi/runtime_ctypes.py (35 changes: 23 additions & 12 deletions)

@@ -20,7 +20,6 @@
 import json
 import numpy as np
 from .base import _LIB, check_call
-from .. import _api_internal
 
 tvm_shape_index_t = ctypes.c_int64
 
@@ -48,6 +47,7 @@ class TVMByteArray(ctypes.Structure):
     _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
                 ("size", ctypes.c_size_t)]
 
+
 class DataType(ctypes.Structure):
     """TVM datatype structure"""
     _fields_ = [("type_code", ctypes.c_uint8),
@@ -89,11 +89,13 @@ def __init__(self, type_str):
             bits = 64
             head = ""
         elif head.startswith("custom"):
+            # pylint: disable=import-outside-toplevel
+            import tvm.runtime._ffi_api
             low, high = head.find('['), head.find(']')
             if not low or not high or low >= high:
                 raise ValueError("Badly formatted custom type string %s" % type_str)
             type_name = head[low + 1:high]
-            self.type_code = _api_internal._datatype_get_type_code(type_name)
+            self.type_code = tvm.runtime._ffi_api._datatype_get_type_code(type_name)
             head = head[high+1:]
         else:
             raise ValueError("Do not know how to handle type %s" % type_str)
@@ -102,13 +104,15 @@
 
 
     def __repr__(self):
+        # pylint: disable=import-outside-toplevel
         if self.bits == 1 and self.lanes == 1:
             return "bool"
         if self.type_code in DataType.CODE2STR:
             type_name = DataType.CODE2STR[self.type_code]
         else:
+            import tvm.runtime._ffi_api
             type_name = "custom[%s]" % \
-                _api_internal._datatype_get_type_name(self.type_code)
+                tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
         x = "%s%d" % (type_name, self.bits)
         if self.lanes != 1:
             x += "x%d" % self.lanes
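As context for the two rewritten call sites above: the constructor parses strings such as `float32x4` or `custom[posit]8`, and only the `custom[...]` branch needs the runtime registry (hence the function-local import, which avoids a circular dependency at module load). A minimal, self-contained sketch of the parsing logic, with the registry lookup stubbed out:

```python
# Sketch mirroring the branches of DataType.__init__; the custom-type
# lookup is stubbed, since the real code resolves the type code through
# tvm.runtime._ffi_api._datatype_get_type_code(type_name).
def parse_dtype(type_str):
    head, lanes = type_str, 1
    arr = head.split("x")
    if len(arr) == 2:
        head, lanes = arr[0], int(arr[1])
    if head.startswith("int"):
        return ("int", int(head[3:]), lanes)
    if head.startswith("uint"):
        return ("uint", int(head[4:]), lanes)
    if head.startswith("float"):
        return ("float", int(head[5:]), lanes)
    if head.startswith("custom"):
        low, high = head.find('['), head.find(']')
        if low == -1 or high == -1 or low >= high:
            raise ValueError("Badly formatted custom type string %s" % type_str)
        type_name = head[low + 1:high]
        return ("custom[%s]" % type_name, int(head[high + 1:]), lanes)
    raise ValueError("Do not know how to handle type %s" % type_str)

print(parse_dtype("float32x4"))       # ('float', 32, 4)
print(parse_dtype("custom[posit]8"))  # ('custom[posit]', 8, 1)
```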
@@ -168,28 +172,35 @@ def __init__(self, device_type, device_id):
         self.device_type = device_type
         self.device_id = device_id
 
+    def _GetDeviceAttr(self, device_type, device_id, attr_id):
+        """Internal helper function to invoke runtime.GetDeviceAttr"""
+        # pylint: disable=import-outside-toplevel
+        import tvm.runtime._ffi_api
+        return tvm.runtime._ffi_api.GetDeviceAttr(
+            device_type, device_id, attr_id)
+
     @property
     def exist(self):
         """Whether this device exists."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 0) != 0
 
     @property
     def max_threads_per_block(self):
         """Maximum number of threads on each block."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 1)
 
     @property
     def warp_size(self):
         """Number of threads that execute concurrently."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 2)
 
     @property
     def max_shared_memory_per_block(self):
         """Total amount of shared memory per block in bytes."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 3)
 
     @property
@@ -203,25 +214,25 @@ def compute_version(self):
         version : str
             The version string in `major.minor` format.
         """
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 4)
 
     @property
     def device_name(self):
         """Return the string name of device."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 5)
 
     @property
     def max_clock_rate(self):
         """Return the max clock frequency of device."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 6)
 
     @property
     def multi_processor_count(self):
         """Return the number of compute units of device."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 7)
 
     @property
@@ -233,7 +244,7 @@ def max_thread_dimensions(self):
         dims: List of int
             The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
         """
-        return json.loads(_api_internal._GetDeviceAttr(
+        return json.loads(self._GetDeviceAttr(
             self.device_type, self.device_id, 8))
 
     def sync(self):
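A usage sketch for the properties above (assumes a CUDA-enabled build; each property forwards a fixed attribute code, 0 through 8, to `runtime.GetDeviceAttr`):

```python
# Sketch: reading device attributes via TVMContext properties.
import tvm

ctx = tvm.gpu(0)
if ctx.exist:                          # attr code 0
    print(ctx.max_threads_per_block)   # attr code 1
    print(ctx.warp_size)               # attr code 2
    print(ctx.max_thread_dimensions)   # attr code 8, JSON-decoded
else:
    print("no GPU present; try tvm.cpu(0)")
```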
python/tvm/autotvm/task/dispatcher.py (4 changes: 2 additions & 2 deletions)

@@ -106,7 +106,7 @@ def update(self, target, workload, cfg):
         def _alter_conv2d_layout(attrs, inputs, tinfo):
             workload = get_conv2d_workload(...)
             dispatch_ctx = autotvm.task.DispatchContext.current
-            target = tvm.target.current_target()
+            target = tvm.target.Target.current()
             config = dispatch_ctx.query(target, workload)
 
             # Get conv2d_NCHWc workload from config
@@ -207,7 +207,7 @@ def _do_reg(myf):
 
     def dispatch_func(func, *args, **kwargs):
         """The wrapped dispatch function"""
-        tgt = _target.current_target()
+        tgt = _target.Target.current()
        workload = func(*args, **kwargs)
         cfg = DispatchContext.current.query(tgt, workload)
         if cfg.is_fallback and not cfg.template_key:
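The change that recurs throughout this PR is mechanical: the module-level `current_target()` helper is replaced by the static method `Target.current()`. A small sketch of the scoping contract both share:

```python
# Sketch: Target.current() reads the innermost active target scope.
import tvm

with tvm.target.create("llvm -mcpu=skylake-avx512"):
    print(tvm.target.Target.current())             # the llvm target above

# Outside any scope, allow_none=True (the default) returns None.
print(tvm.target.Target.current(allow_none=True))  # None
```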
python/tvm/build_module.py (5 changes: 3 additions & 2 deletions)

@@ -25,14 +25,15 @@
 
 from tvm.runtime import Object, ndarray
 from tvm.ir import container
+from tvm.target import codegen
+
 from . import api
 from . import _api_internal
 from . import tensor
 from . import schedule
 from . import expr
 from . import ir_pass
 from . import stmt as _stmt
-from . import codegen
 from . import target as _target
 from . import make
 from .stmt import LoweredFunc
@@ -602,7 +603,7 @@ def build(inputs,
                          "LoweredFunc.")
 
     if not isinstance(inputs, (dict, container.Map)):
-        target = _target.current_target() if target is None else target
+        target = _target.Target.current() if target is None else target
         target = target if target else "llvm"
         target_flist = {target: flist}
     else:
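Inside `build`, the target is now resolved in order: the explicit argument, then the enclosing target scope via `Target.current()`, then the `"llvm"` default. A hedged sketch of that behavior using the classic TE workflow:

```python
# Sketch: tvm.build picks up the target from an enclosing `with` scope
# when the `target` argument is omitted.
import numpy as np
import tvm

n = 16
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)

with tvm.target.create("llvm"):
    fadd = tvm.build(s, [A, B])   # target comes from the scope

a = tvm.nd.array(np.arange(n, dtype="float32"))
b = tvm.nd.array(np.zeros(n, dtype="float32"))
fadd(a, b)
assert np.allclose(b.asnumpy(), a.asnumpy() + 1.0)
```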
python/tvm/contrib/clang.py (9 changes: 4 additions & 5 deletions)

@@ -16,11 +16,10 @@
 # under the License.
 """Util to invoke clang in the system."""
 # pylint: disable=invalid-name
-from __future__ import absolute_import as _abs
 import subprocess
 
-from .._ffi.base import py_str
-from .. import codegen
+from tvm._ffi.base import py_str
+import tvm.target
 from . import util
 
 
@@ -44,8 +43,8 @@ def find_clang(required=True):
         matches the major LLVM version that TVM was built with
     """
     cc_list = []
-    if hasattr(codegen, "llvm_version_major"):
-        major = codegen.llvm_version_major()
+    major = tvm.target.codegen.llvm_version_major(allow_none=True)
+    if major is not None:
         cc_list += ["clang-%d.0" % major]
         cc_list += ["clang-%d" % major]
     cc_list += ["clang"]
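For reference, a sketch of one way `find_clang` is exercised, modeled on the tensorize workflow; the C snippet and options are illustrative, and the result depends on the local toolchain:

```python
# Sketch: locate a clang whose major version matches the LLVM that TVM
# was built against, then compile a small C snippet to textual LLVM IR.
from tvm.contrib import clang

cc = clang.find_clang(required=False)
if cc:
    ll = clang.create_llvm("int add1(int x) { return x + 1; }", cc=cc[0])
    print(ll[:80])   # beginning of the LLVM IR
else:
    print("clang not found on this system")
```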
python/tvm/contrib/rocm.py (10 changes: 6 additions & 4 deletions)

@@ -17,9 +17,11 @@
 """Utility for ROCm backend"""
 import subprocess
 from os.path import join, exists
+
+from tvm._ffi.base import py_str
+import tvm.target
+
 from . import util
-from .._ffi.base import py_str
-from .. import codegen
 from ..api import register_func, convert
 
 def find_lld(required=True):
@@ -42,8 +44,8 @@ def find_lld(required=True):
         matches the major LLVM version that TVM was built with
     """
     lld_list = []
-    if hasattr(codegen, "llvm_version_major"):
-        major = codegen.llvm_version_major()
+    major = tvm.target.codegen.llvm_version_major(allow_none=True)
+    if major is not None:
        lld_list += ["ld.lld-%d.0" % major]
         lld_list += ["ld.lld-%d" % major]
     lld_list += ["ld.lld"]
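Both `find_clang` and `find_lld` now use the same probe: `llvm_version_major(allow_none=True)` returns `None` on a TVM build without LLVM instead of raising, which replaces the old `hasattr(codegen, "llvm_version_major")` guard. A sketch of the pattern:

```python
# Sketch: probing the LLVM version TVM was built against.
import tvm.target

major = tvm.target.codegen.llvm_version_major(allow_none=True)
if major is not None:
    print("built against LLVM %d" % major)
else:
    print("TVM was built without LLVM support")
```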
python/tvm/hybrid/calls.py (4 changes: 2 additions & 2 deletions)

@@ -154,8 +154,8 @@ def max_num_threads(func_id, args):
     _internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!")
     _internal_assert(args.__len__() <= 1, "At most one argument accepted!")
     if args.__len__() == 0:
-        res = _tgt.current_target().max_num_threads
+        res = _tgt.Target.current().max_num_threads
     else:
         _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint")
-        res = _tgt.current_target(args[0].value).max_num_threads
+        res = _tgt.Target.current(args[0].value).max_num_threads
     return _api.convert(res)
python/tvm/hybrid/runtime.py (2 changes: 1 addition & 1 deletion)

@@ -107,7 +107,7 @@ def sigmoid(x):
 
 def max_num_threads(allow_none=True):
     """Get max number of threads for GPU targets."""
-    return target.current_target(allow_none).max_num_threads
+    return target.Target.current(allow_none).max_num_threads
 
 
 HYBRID_GLOBALS = {
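In both hybrid-script call sites, `max_num_threads` ultimately reads the attribute off the current target rather than the physical device. A sketch, assuming a CUDA target is in scope:

```python
# Sketch: max_num_threads is a property of the target definition.
import tvm

with tvm.target.create("cuda"):
    print(tvm.target.Target.current().max_num_threads)  # typically 1024
```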
python/tvm/intrin.py (4 changes: 2 additions & 2 deletions)

@@ -17,7 +17,7 @@
 """Expression Intrinsics and math functions in TVM."""
 # pylint: disable=redefined-builtin
 import tvm._ffi
-import tvm.codegen
+import tvm.target.codegen
 
 from . import make as _make
 from .api import convert, const
@@ -189,7 +189,7 @@ def call_llvm_intrin(dtype, name, *args):
     call : Expr
         The call expression.
     """
-    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
+    llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(name)
     assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
     return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
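A hedged usage sketch for `call_llvm_intrin`: by TVM convention the first argument after the name is a `uint32` constant giving the number of overloaded signatures, and `llvm.sqrt` is a standard overloaded LLVM intrinsic (availability of others depends on the LLVM build):

```python
# Sketch: emitting an LLVM intrinsic call from a TE compute expression.
import tvm

n = 8
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute(
    (n,),
    lambda i: tvm.call_llvm_intrin("float32", "llvm.sqrt",
                                   tvm.const(1, "uint32"), A[i]),
    name="B")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")  # requires an LLVM-enabled TVM build
```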
python/tvm/relay/backend/vm.py (2 changes: 1 addition & 1 deletion)

@@ -176,7 +176,7 @@ def get_exec(self):
 
     def _update_target(self, target):
         """Update target."""
-        target = target if target else tvm.target.current_target()
+        target = target if target else tvm.target.Target.current()
         if target is None:
             raise ValueError("Target is not set in env or passed as argument.")
         tgts = {}
python/tvm/relay/build_module.py (2 changes: 1 addition & 1 deletion)

@@ -33,7 +33,7 @@
 from .backend.vm import VMExecutor
 
 def _update_target(target):
-    target = target if target else _target.current_target()
+    target = target if target else _target.Target.current()
     if target is None:
         raise ValueError("Target is not set in env or passed as argument.")
 
python/tvm/relay/qnn/op/legalizations.py (4 changes: 2 additions & 2 deletions)

@@ -220,13 +220,13 @@ def _shift(data, zero_point, out_dtype):
 
 def is_fast_int8_on_intel():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
     return intel_supported_arches.intersection(set(target.options))
 
 def is_fast_int8_on_arm():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     return '+v8.2a,+dotprod' in ' '.join(target.options)
 
 ########################
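These checks read the option strings of the active target scope; the `-mcpu` flag they look for is whatever the user put in the target string. A sketch of what they see:

```python
# Sketch: Target.current(allow_none=False) raises outside a scope;
# inside one, .options exposes the flags parsed from the target string.
import tvm

with tvm.target.create("llvm -mcpu=skylake-avx512"):
    tgt = tvm.target.Target.current(allow_none=False)
    print('-mcpu=skylake-avx512' in tgt.options)  # True
```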
python/tvm/relay/quantize/_calibrate.py (4 changes: 2 additions & 2 deletions)

@@ -37,8 +37,8 @@ def _get_profile_runtime(mod):
     func = mod['main']
     func = _quantize.CreateStatsCollector(func)
 
-    if tvm.target.current_target():
-        target = tvm.target.current_target()
+    if tvm.target.Target.current():
+        target = tvm.target.Target.current()
         ctx = tvm.context(target.target_name)
     else:
         target = 'llvm'
python/tvm/relay/quantize/_partition.py (6 changes: 2 additions & 4 deletions)

@@ -16,9 +16,7 @@
 # under the License.
 #pylint: disable=unused-argument,inconsistent-return-statements
 """Internal module for registering attribute for annotation."""
-from __future__ import absolute_import
-
-from ... import target as _target
+import tvm
 from .. import expr as _expr
 from .. import analysis as _analysis
 from ..base import register_relay_node
@@ -133,7 +131,7 @@ def add_partition_generic(ref_call, new_args, ctx):
 
 @register_partition_function("add")
 def add_partition_function(ref_call, new_args, ctx):
     """Rewrite function for ewise add for partition"""
-    target = _target.current_target()
+    target = tvm.target.Target.current()
     if target and 'cuda' in target.keys:
         #TODO(wuwei/ziheng) cuda specific rules
         return add_partition_generic(ref_call, new_args, ctx)