Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions python/tvm/hybrid/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,15 @@ def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
parser = parse_python(src, args)
parser = parse_python(src, func.__globals__, args)

input_tensors = []
for i in args:
if isinstance(i, Tensor):
input_tensors.append(i)

op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]

return res[0] if len(res) == 1 else res

intersect = _enter_hybrid_runtime(func)
Expand Down
92 changes: 92 additions & 0 deletions python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""

from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert

#pylint: disable=redefined-builtin

LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
}

def _range(annotation, args):
"""Handling TVM loop types"""
n = len(args)
if n == 1:
low, ext = _api.const(0, dtype='int32'), args[0]
else:
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
low, ext = args[0], args[1]
if not ir_pass.Equal(low, _api.const(0, dtype='int32')):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
iter_var = None
return iter_var, low, ext, for_type


range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name


def bind(func_id, args):
"""Handling TVM thread binding"""
_internal_assert(func_id == "bind", "This function cannot be directly invoked!")
_internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(args[0], str), \
"A loop bind's first argument should be a string!")
iter_var = _api.thread_axis(args[0])
low, ext = _api.const(0), args[1]
for_type = None
return iter_var, low, ext, for_type


def _math_intrin(func_id, args):
from .. import intrin
return getattr(intrin, func_id)(*args)

sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name


def _min_max(func_id, args):
_internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
return getattr(_make, func_id.title())(args[0], args[1])


min = max = _min_max #pylint: disable=invalid-name


def _allocate_tensor(func_id, args):
"""Handling TVM tensor allocation.
You may refer hybrid.intrin.allocate for more details."""
n = len(args)
_internal_assert(isinstance(_api.convert(args[0]), Array), \
"allocate's first argument should be a tuple of shape!")
shape = args[0]
for i in shape:
_internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
if n > 1:
_internal_assert(isinstance(args[1], str),
"The data type should be an str")
_internal_assert(args[1].startswith('int') or args[1].startswith('float'), \
"The data type should be either int or float!")
dtype = args[1]
else:
dtype = 'float32'
if n > 2:
_internal_assert(isinstance(args[2], str), \
"The data scope should be an string")
_internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
scope = args[2]
else:
scope = 'global' if func_id != 'output_tensor' else 'output'
return (shape, dtype, scope)

output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
15 changes: 1 addition & 14 deletions python/tvm/hybrid/intrin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""

import numpy
from ..stmt import For

class _range(object):
"""Base class of the loop ranges in hybrid script"""
Expand Down Expand Up @@ -102,15 +101,3 @@ def sigmoid(x):
'sigmoid' : sigmoid,
'popcount' : popcount
}


LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
'bind' : None
}


MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount']
Loading