18 changes: 12 additions & 6 deletions python/tvm/script/tir/scope_handler.py
@@ -467,18 +467,24 @@ def enter_scope(

self.node = node
self.context = context
# generate loop vars
self.loop_vars = [
tvm.te.var(name, dtype="int32", span=span) for name, span in zip(loop_var_names, spans)
]
# collect loop infos by calling self.func
call_with_error_reporting(context.report_error, span, self.func, *arg_list)
if len(self.loop_vars) != len(self.loop_info):
if len(loop_var_names) != len(self.loop_info):
self.context.report_error(
f"Inconsistent number of vars and loops, got {len(self.loop_vars)} "
f"Inconsistent number of vars and loops, got {len(loop_var_names)} "
+ f"vs {len(self.loop_info)}",
self.node.span,
)
# generate loop vars
self.loop_vars = []
for name, lv_span, li in zip(loop_var_names, spans, self.loop_info):
if not li.begin.dtype.startswith("int"):
raise NotImplementedError(f"Unsupported dtype in loop begin: {li.begin.dtype}")
if not li.extent.dtype.startswith("int"):
raise NotImplementedError(f"Unsupported dtype in loop extent: {li.extent.dtype}")
dtype = "int64" if "int64" in [li.begin.dtype, li.extent.dtype] else "int32"
self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span))

for loop_var, loop_info in zip(self.loop_vars, self.loop_info):
context.update_symbol(loop_var.name, loop_var, node)
context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent)
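
Note: a minimal sketch of the dtype-selection rule applied above (the helper below is hypothetical, not part of the patch): the loop variable becomes int64 exactly when either bound is int64, and non-integer bounds are rejected.

def _infer_loop_var_dtype(begin_dtype, extent_dtype):
    # only integer bounds are supported; widen to int64 if either bound needs it
    for dt in (begin_dtype, extent_dtype):
        if not dt.startswith("int"):
            raise NotImplementedError(f"Unsupported dtype in loop bound: {dt}")
    return "int64" if "int64" in (begin_dtype, extent_dtype) else "int32"

assert _infer_loop_var_dtype("int32", "int64") == "int64"
assert _infer_loop_var_dtype("int32", "int32") == "int32"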
8 changes: 7 additions & 1 deletion python/tvm/te/hybrid/parser.py
@@ -511,7 +511,13 @@ def visit_For(self, node):

if iter_var is None:
_internal_assert(kind is not None, "The loop iterating function parse error!")
offset = iter_var = tvm.te.var(_name)
if isinstance(ext, _expr.PrimExpr):
dtype = ext.dtype
elif isinstance(ext, int):
dtype = "int32"
else:
raise NotImplementedError(f"Unsupported type of ext: {type(ext)}")
offset = iter_var = tvm.te.var(_name, dtype=dtype)
if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, "int32")):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
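
Note: a hedged usage sketch, assuming the te.hybrid script API; the copy kernel below is made up for illustration and is not part of this change. With the dtype now taken from the extent, a loop over an int64 shape gets an int64 loop variable instead of a hard-coded int32 one.

from tvm import te

@te.hybrid.script
def copy(a):
    b = output_tensor(a.shape, a.dtype)
    for i in range(a.shape[0]):  # if a.shape[0] is int64, i is created as int64
        b[i] = a[i]
    return b

a = te.placeholder((te.size_var("n", dtype="int64"),), name="a")
b = copy(a)  # parsing happens at call time; the loop var dtype follows the extent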
22 changes: 21 additions & 1 deletion python/tvm/tir/ir_builder.py
@@ -201,7 +201,7 @@ def scope_attr(self, node, attr_key, value):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
"""Create a for iteration scope.

Parameters
@@ -240,6 +240,26 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3)
self.nidx += 1
self._seq_stack.append([])

# auto infer dtype when it's not specified
def get_dtype(expr):
if isinstance(expr, _expr.PrimExpr):
if not expr.dtype.startswith("int"):
raise NotImplementedError(
f"Infer loop_var dtype failed:"
f" unsupported dtype in loop begin or end {expr.dtype}"
)
return expr.dtype
if isinstance(expr, int):
return "int32"
raise NotImplementedError(
f"Infer loop_var dtype failed:"
f" unsupported dtype in loop begin or end {expr.dtype}"
)

if dtype is None:
dtype = "int64" if "int64" in [get_dtype(begin), get_dtype(end)] else "int32"

loop_var = _expr.Var(name, dtype=dtype)
extent = end if begin == 0 else (end - begin)
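
Note: a hedged usage sketch of the new default (variable names are illustrative). With dtype left unspecified, for_range infers the loop variable dtype from begin and end:

import tvm
from tvm import te

ib = tvm.tir.ir_builder.create()
n = te.size_var("n", dtype="int64")

with ib.for_range(0, n) as i:  # begin is a Python int, end is int64 -> i is "int64"
    pass

with ib.for_range(0, 8) as j:  # both bounds fit in int32 -> j stays "int32"
    pass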

4 changes: 2 additions & 2 deletions python/tvm/topi/cuda/scan.py
@@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
# Up Sweep of exclusive scan
lim = ceil_log2(scan_axis_size)

with ib.for_range(0, lim, dtype="int64") as l2_width:
with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
width = 2 << l2_width

with ib.new_scope():
@@ -143,7 +143,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
reduction[bx] = output[(bx + 1) * scan_axis_size - 1]
output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype)

with ib.for_range(0, lim, dtype="int64") as l2_width:
with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
width = 2 << (lim - l2_width - 1)

with ib.new_scope():
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/sort.py
@@ -323,7 +323,7 @@ def assign_j():
with ib.else_scope():
assign_j()

with ib.for_range(0, upper_lim - lower_lim, dtype="int64") as l2_width:
with ib.for_range(0, cast(upper_lim - lower_lim, "int64"), dtype="int64") as l2_width:
width = 2 << (l2_width + lower_lim)
# Define and launch the cuda kernel
with ib.new_scope():
2 changes: 1 addition & 1 deletion src/te/operation/op_utils.cc
@@ -128,7 +128,7 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back(LetStmt(var, promote_to_bound_dtype(dom->min), no_op));
value_map[iv] = promote_to_bound_dtype(dom->min);
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op));
nest[i + 1].emplace_back(For(var, 0, promote_to_bound_dtype(dom->extent), kind, no_op));
value_map[iv] = promote_to_bound_dtype(var);
} else {
Var idx(bind_iv->var->name_hint + ".idx", iv->var.dtype());
20 changes: 20 additions & 0 deletions src/tir/ir/stmt.cc
@@ -147,6 +147,26 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
ICHECK(loop_var.dtype().is_scalar());
ICHECK(body.defined());

// When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them
// without raising errors.
auto try_promote_imm_dtype = [&](const PrimExpr& e) {
ICHECK(e.dtype().bits() <= loop_var.dtype().bits())
<< " Loop variable's dtype (" << loop_var.dtype()
<< ") is narrower than that of `min` or `extent` (" << e.dtype() << ")";
const IntImmNode* a = e.as<IntImmNode>();
if (a && e.dtype().bits() < loop_var.dtype().bits()) {
return make_const(loop_var.dtype(), a->value);
} else {
return e;
}
};

min = try_promote_imm_dtype(min);
extent = try_promote_imm_dtype(extent);

ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype();
ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype();

ObjectPtr<ForNode> node = make_object<ForNode>();
node->loop_var = std::move(loop_var);
node->min = std::move(min);
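
Note: a hedged Python-level sketch of the stricter constructor behaviour, using the tvm.tir bindings exercised by the tests further down: IntImm bounds narrower than the loop variable are promoted to its dtype, while any other mismatch now trips the ICHECKs.

from tvm import tir

i = tir.Var("i", "int64")
body = tir.Evaluate(tir.IntImm("int32", 0))

# int32 constant bounds are promoted to the int64 loop variable without error
loop = tir.For(i, tir.IntImm("int32", 0), tir.IntImm("int32", 16),
               tir.ForKind.SERIAL, body)
assert loop.min.dtype == "int64" and loop.extent.dtype == "int64"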
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/cache_read_write.cc
@@ -108,7 +108,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
std::vector<PrimExpr> iter_values;
// Create loop vars and block vars' binding_value
for (const Range& axis_range : cache_region->region) {
Var loop_var("ax" + std::to_string(loop_vars.size()));
Var loop_var("ax" + std::to_string(loop_vars.size()), axis_range->extent.dtype());
loop_vars.push_back(loop_var);
iter_values.push_back(axis_range->min + loop_var);
}
3 changes: 2 additions & 1 deletion src/tir/transforms/vectorize_loop.cc
@@ -569,7 +569,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
return For(idx, 0, var_lanes_, ForKind::kSerial, stmt);
return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial,
stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode* op) final {
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_buffer.py
@@ -99,7 +99,7 @@ def test_buffer_vload_nullptr():
buf_load = tvm.tir.expr.BufferLoad(buffer=buf, indices=tvm.runtime.convert([0]))
buf_load_stmt = tvm.tir.stmt.Evaluate(buf_load)
for_loop = tvm.tir.stmt.For(
loop_var=var, kind=0, min_val=0, extent=buf_load, body=buf_load_stmt
loop_var=var, kind=0, min_val=0, extent=tvm.tir.Cast("int32", buf_load), body=buf_load_stmt
)
buf_func = tvm.tir.PrimFunc(params={}, body=for_loop)
mod = tvm.IRModule({"main": buf_func})
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_ir_builder.py
@@ -517,7 +517,7 @@ def test_device_ir(A, B):
temp[tx] = Aptr[tx]
depth = tvm.tir.log2(cast(n, "float32"))

with ib.for_range(0, depth) as i:
with ib.for_range(0, cast(tvm.tir.ceil(depth), n.dtype)) as i:
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
d = n >> (i + 1)
with ib.if_scope(tx < d):
4 changes: 2 additions & 2 deletions tests/python/unittest/test_tir_transform_ir_utils.py
@@ -26,9 +26,9 @@ def test_convert_ssa():
var_type = ir.PointerType(ir.PrimType(dtype))
v = tir.Var("i1", var_type)
buf = tir.decl_buffer([16], dtype=dtype, data=v)
for_stmt = tir.For(v, zero, zero, tir.ForKind.SERIAL, nop)
let = tir.LetStmt(v, v, nop)
load = tir.Evaluate(tir.BufferLoad(buf, [zero]))
seq = tir.SeqStmt([for_stmt, for_stmt, load])
seq = tir.SeqStmt([let, let, load])
func = tir.PrimFunc([], seq)
mod = tvm.IRModule({"main": func})
mod = tir.transform.InjectVirtualThread()(
10 changes: 10 additions & 0 deletions tests/python/unittest/test_tir_transform_vectorize.py
@@ -85,6 +85,16 @@ def test_vectorize_with_if():
assert isinstance(stmt.else_case, tvm.tir.For)


def test_vectorize_with_if_cond_int64():
m = te.size_var("m", dtype="int64")
A = te.placeholder((m,), name="A", dtype="float32")
B = te.compute((m,), lambda i: te.if_then_else(i < 2, A[i], A[i] * 2), name="B")
s = te.create_schedule(B.op)
x, y = s[B].split(B.op.axis[0], factor=4)
s[B].vectorize(y)
f = tvm.build(s, [A, B], "llvm")


def test_vectorize_let():
v = tvm.tir.Var("v", "float32")
ib = tvm.tir.ir_builder.create()