Merged
49 changes: 40 additions & 9 deletions python/tvm/api.py
@@ -26,22 +26,53 @@


def min_value(dtype):
"""minimum value of dtype"""
"""minimum value of dtype

Parameters
----------
dtype : str
The data type.

Returns
-------
value : tvm.Expr
The minimum value of dtype.
"""
return _api_internal._min_value(dtype)


def max_value(dtype):
"""maximum value of dtype"""
"""maximum value of dtype

Parameters
----------
dtype : str
The data type.

Returns
-------
value : tvm.Expr
The maximum value of dtype.
"""
return _api_internal._max_value(dtype)


def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
if isinstance(value, _Integral):
dtype = 'int32'
else:
dtype = 'float32'
def const(value, dtype):
"""construct a constant

Parameters
----------
value : number
The content of the constant number.

dtype : str
The data type.

Returns
-------
const_val: tvm.Expr
The result expression.
"""
return _api_internal._const(value, dtype)
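
The dtype default is gone from api.const, so every caller now has to spell the type out. A minimal sketch of the updated calling convention, using only the public tvm.const API exercised by the tests in this PR (nothing beyond that is assumed):

import tvm

# Before this PR, tvm.const(1) defaulted to int32 and tvm.const(1.0) to
# float32; the dtype argument is now mandatory.
x = tvm.const(1, "int32")      # integer immediate, dtype int32
y = tvm.const(1.0, "float32")  # floating-point immediate, dtype float32
assert x.dtype == "int32"
assert y.dtype == "float32"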


2 changes: 1 addition & 1 deletion python/tvm/hybrid/calls.py
@@ -43,7 +43,7 @@ def bind(func_id, args):
_internal_assert(isinstance(args[0], str), \
"A loop bind's first argument should be a string!")
iter_var = _api.thread_axis(args[0])
low, ext = _api.const(0), args[1]
low, ext = _api.const(0, "int32"), args[1]
for_type = None
return iter_var, low, ext, for_type

23 changes: 17 additions & 6 deletions python/tvm/hybrid/parser.py
@@ -4,6 +4,8 @@
import operator
import logging
import sys
from numbers import Integral

from .util import _internal_assert
from . import calls
from . import util
@@ -137,6 +139,15 @@ def _get_buffer_from_id(self, s, for_provide=False):
return self._args[s]
return self.alloc_buffers[s][0]

def _const(self, value, dtype=None):
if dtype is None:
if isinstance(value, bool):
dtype = "bool"
elif isinstance(value, Integral):
dtype = "int32"
else:
dtype = "float32"
return _api.const(value, dtype)

#pylint: disable=invalid-name, missing-docstring
def visit_Module(self, node):
@@ -172,9 +183,9 @@ def visit_Name(self, node):
if isinstance(res, tuple):
buf = res[0]
if isinstance(node.ctx, ast.Load):
return _make.Call(buf.dtype, buf.name, [_api.const(0)], \
return _make.Call(buf.dtype, buf.name, [self._const(0)], \
_expr.Call.Halide, buf.op, buf.value_index)
return buf, [_api.const(0)]
return buf, [self._const(0)]
if isinstance(node.ctx, ast.Load):
return res
return None
@@ -183,7 +194,7 @@ def visit_Name(self, node):


def visit_Num(self, node):
return _api.const(node.n)
return self._const(node.n)


def visit_AugAssign(self, node):
@@ -193,7 +204,7 @@ def visit_AugAssign(self, node):
_internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
buf, args = buf
else:
args = [_api.const(0)]
args = [self._const(0)]
_internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")

read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
@@ -378,7 +389,7 @@ def visit_For(self, node):
if iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!")
offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0)):
if not _ir_pass.Equal(low, self._const(0)):
offset = iter_var + low
self.loops_above[_name] = offset
else:
@@ -389,7 +400,7 @@ def visit_For(self, node):
if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
else:
res = _make.For(iter_var, _api.const(0), ext, for_type, 0, _body)
res = _make.For(iter_var, self._const(0), ext, for_type, 0, _body)
self.loops_above.pop(_name)
return res
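
Inside hybrid scripts, numeric literals still get their dtype inferred, but the inference now lives in the parser's _const helper rather than in api.const. A standalone sketch of the same rules (infer_literal_dtype is an illustrative name, not part of the codebase; bool is checked before Integral because bool subclasses int in Python):

from numbers import Integral

def infer_literal_dtype(value):
    # Mirrors the dtype-inference order of HybridParser._const above.
    if isinstance(value, bool):
        return "bool"
    if isinstance(value, Integral):
        return "int32"
    return "float32"

assert infer_literal_dtype(True) == "bool"
assert infer_literal_dtype(3) == "int32"
assert infer_literal_dtype(2.5) == "float32"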

6 changes: 0 additions & 6 deletions python/tvm/relay/expr.py
@@ -465,12 +465,6 @@ def const(value, dtype=None):
"""
if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype)
# convert default to int32 and float32
if dtype is None:
if value.dtype == "float64":
value = value.astype("float32")
elif value.dtype == "int64":
value = value.astype("int32")
if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)
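
With the implicit float64-to-float32 / int64-to-int32 narrowing removed, relay.const keeps whatever dtype NumPy infers unless the caller passes one explicitly. A hedged sketch of the resulting behaviour, assuming NumPy's usual float64/int64 defaults for Python scalars:

from tvm import relay

c_default = relay.const(1.0)         # presumably float64 now; no silent downcast
c_f32 = relay.const(1.0, "float32")  # explicit dtype, the pattern used in the updated tests
c_i32 = relay.const(0, "int32")      # as in test_backend_interpreter.py below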

11 changes: 6 additions & 5 deletions tests/python/relay/test_backend_interpreter.py
@@ -37,7 +37,8 @@ def test_tuple_value():
def test_id():
x = relay.var('x', 'float32')
ident = relay.Function([x], x)
check_eval(ident, [1.0], 1.0)
one = np.array(1.0, 'float32')
check_eval(ident, [one], one)


def test_add_const():
@@ -60,8 +61,8 @@ def test_equal():
j = relay.var('i', shape=[], dtype='int32')
z = relay.equal(i, j)
func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool'))
i_data = relay.const(0)
j_data = relay.const(0)
i_data = relay.const(0, 'int32')
j_data = relay.const(0, 'int32')
check_eval(func, [i_data, j_data], True)


@@ -96,10 +97,10 @@ def test_loop():
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(relay.equal(i, relay.const(0))):
with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
sb.ret(accum)
with sb.else_scope():
one_less = relay.subtract(i, relay.const(1))
one_less = relay.subtract(i, relay.const(1, 'int32'))
new_accum = relay.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
5 changes: 3 additions & 2 deletions tests/python/relay/test_debug.py
@@ -13,10 +13,11 @@ def did_exec(x):
global _test_debug_hit
_test_debug_hit = True
prog = debug(x, debug_func=did_exec)
result = ex.evaluate(prog, { x: const(1) })
result = ex.evaluate(prog, { x: const(1, 'int32') })
assert _test_debug_hit
assert result.asnumpy() == 1


def test_debug_with_expr():
global _test_debug_hit
_test_debug_hit = False
@@ -27,6 +28,6 @@ def did_exec(x):
global _test_debug_hit
_test_debug_hit = True
prog = debug(x + x * x, debug_func=did_exec)
result = ex.evaluate(prog, { x: const(2) })
result = ex.evaluate(prog, { x: const(2, 'int32') })
assert _test_debug_hit
assert result.asnumpy() == 6
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_level3.py
@@ -329,7 +329,7 @@ def verify_full(fill_value, src_shape, dtype):
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(fill_value)
op_res = intrp.evaluate(func)(np.array(fill_value, dtype))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_full(4, (1, 3, 4, 4), "int32")
verify_full(4.0, (1, 4), "float32")
@@ -365,7 +365,7 @@ def verify_full_like(base, fill_value, dtype):
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, fill_value)
op_res = intrp.evaluate(func)(x_data, np.array(fill_value, dtype))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_full_like((1, 3, 4, 4), 4, "int32")
verify_full_like((1, 1), 44.0, "float32")
5 changes: 2 additions & 3 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -20,13 +20,13 @@ def before():
@register_alter_op_layout("nn.conv2d", level=100)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0))
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)

def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var('weight', shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0)),
y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
@@ -313,4 +313,3 @@ def expected():
test_alter_layout_dual_path()
test_alter_layout_resnet()
test_alter_layout_broadcast_op()

15 changes: 8 additions & 7 deletions tests/python/unittest/test_arith_simplify.py
@@ -21,8 +21,8 @@ def test_simplify():
assert zz.a == x and zz.b.value == 4

n = tvm.var('n')
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % (-1)), tvm.const(0))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % (-1)), tvm.const(0, "int32"))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32"))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n)
tvm.ir_pass.CanonicalSimplify(n / (-1))
# This is not true in the current implementation
@@ -67,10 +67,11 @@ def test_modular():
ry = tvm.var("ry")
y = tvm.var("y")
x = tvm.var("x")
vmap = {rx: tvm.Range(tvm.const(0), tvm.const(3)),
ry: tvm.Range(tvm.const(0), tvm.const(3)),
y: tvm.Range(tvm.const(0), tvm.const(2)),
x: tvm.Range(tvm.const(0), tvm.const(14))}
i32_const = lambda x: tvm.const(x, "int32")
vmap = {rx: tvm.Range(i32_const(0), i32_const(3)),
ry: tvm.Range(i32_const(0), i32_const(3)),
y: tvm.Range(i32_const(0), i32_const(2)),
x: tvm.Range(i32_const(0), i32_const(14))}
idx = ry * 16 + rx + y * 16 + x
z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap)
@@ -82,4 +83,4 @@
test_modular()
test_simplify()
test_mul()
test_simplify_minmax()
test_simplify_minmax()
6 changes: 3 additions & 3 deletions tests/python/unittest/test_lang_basic.py
@@ -1,20 +1,20 @@
import tvm

def test_const():
x = tvm.const(1)
x = tvm.const(1, "int32")
print(x.dtype)
assert x.dtype == tvm.int32
assert isinstance(x, tvm.expr.IntImm)

def test_make():
x = tvm.const(1)
x = tvm.const(1, "int32")
y = tvm.var("x")
z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)

def test_ir():
x = tvm.const(1)
x = tvm.const(1, "int32")
y = tvm.make.IntImm('int32', 1)
z = x + y
stmt = tvm.make.Evaluate(z)
2 changes: 1 addition & 1 deletion tests/python/unittest/test_lang_operator.py
@@ -2,7 +2,7 @@

def test_const_fold():
def check(f, *args):
x = f(*[tvm.const(x) for x in args])
x = f(*[tvm.const(x, "int32") for x in args])
y = f(*args)
if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
raise ValueError("check error: %s vs %s " % (x, y))
8 changes: 4 additions & 4 deletions tests/python/unittest/test_lang_reflection.py
@@ -2,8 +2,8 @@

def test_const_saveload_json():
# save load json
x = tvm.const(1)
y = tvm.const(10)
x = tvm.const(1, "int32")
y = tvm.const(10, "int32")
z = x + y
z = z + z
json_str = tvm.save_json(z)
@@ -13,8 +13,8 @@ def test_const_saveload_json():

def test_make_smap():
# save load json
x = tvm.const(1)
y = tvm.const(10)
x = tvm.const(1, "int32")
y = tvm.const(10, "int32")
z = tvm.expr.Add(x, y)
smap = tvm.convert({"z": z, "x": x})
json_str = tvm.save_json(tvm.convert([smap]))
4 changes: 2 additions & 2 deletions tests/python/unittest/test_pass_simplify.py
@@ -29,13 +29,13 @@ def test_basic():

def test_bound():
m = tvm.var('m')
vrange = tvm.convert({m: tvm.Range(tvm.const(0), tvm.const(10))})
vrange = tvm.convert({m: tvm.Range(tvm.const(0, "int32"), tvm.const(10, "int32"))})
ret = tvm.ir_pass.Simplify(m % 10, vrange)
assert ret == m

def test_canonical():
x = tvm.var("x")
z = tvm.const(3)
z = tvm.const(3, "int32")
ret = tvm.ir_pass.CanonicalSimplify(x / (z*z) - x / (z*z))
assert(tvm.ir_pass.Equal(ret, 0))

3 changes: 2 additions & 1 deletion tests/python/unittest/test_pass_storage_rewrite.py
@@ -238,7 +238,8 @@ def test_parallel_alloc():
n = tvm.var("n")
with ib.for_range(0, n, name="t") as i:
ib.scope_attr(
tvm.const(1) , "pragma_scope", tvm.make.StringImm("parallel_launch_point"))
tvm.const(1, "int32") , "pragma_scope",
tvm.make.StringImm("parallel_launch_point"))
with ib.for_range(0, n, name="i", for_type="parallel") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", n, name="A", scope="global")
4 changes: 2 additions & 2 deletions tests/python/unittest/test_pass_unroll.py
@@ -24,7 +24,7 @@ def test_unroll_loop():
assert ret.for_type == tvm.stmt.For.Unrolled

ib = tvm.ir_builder.create()
ib.scope_attr(tvm.const(0), "pragma_auto_unroll_max_step", 16)
ib.scope_attr(tvm.const(0, "int32"), "pragma_auto_unroll_max_step", 16)
ib.emit(stmt)
wrapped = ib.get()
wrapped = tvm.make.Block(wrapped, stmt)
@@ -54,4 +54,4 @@ def test_unroll_fake_loop():

if __name__ == "__main__":
test_unroll_loop()
test_unroll_fake_loop()
test_unroll_fake_loop()
3 changes: 2 additions & 1 deletion tests/python/unittest/test_schedule_schedule_ops.py
@@ -272,7 +272,8 @@ def _compute(*indice):

def test_schedule_bound_condition():
A = tvm.placeholder((64,), name='A', dtype="float32")
Apad = tvm.compute((66,), lambda i: tvm.select(tvm.all(i>0, i < 65), A[i-1], tvm.const(0.)), name='Apad')
Apad = tvm.compute((66,), lambda i: tvm.select(
tvm.all(i>0, i < 65), A[i-1], tvm.const(0., "float32")), name='Apad')
Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2')
s = tvm.create_schedule(Apad2.op)
AL1 = s.cache_read(A,"local",[Apad])