Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,14 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
PrimExpr extent = arith::Analyzer().Simplify(stop - start);
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
int bits = std::max(min.dtype().bits(), extent.dtype().bits());
n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))};
DataType dtype = DataType(min.dtype().code(), bits, 1);
n->vars = {Var("v", dtype)};
n->doms = {Range::FromMinExtent(min, extent)};
n->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For {
n->f_make_for_loop = [annotations, thread, dtype](Array<Var> vars, Array<Range> doms,
Stmt body) -> For {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex,
thread);
IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
annotations.value_or(Map<String, ObjectRef>()));
};
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ class CrossThreadReductionTransformer : public StmtMutator {
/*kind=*/ForKind::kThreadBinding, //
/*body=*/body, //
/*thread_binding=*/
IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
IterVar(NullValue<Range>(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex,
"threadIdx." + dim_index));
}
return body;
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,5 +325,20 @@ def evaluated(A: T.Buffer((2, 128, 128), "int32")):
tvm.ir.assert_structural_equal(with_builtin, evaluated)


def test_thread_binding_dtype():
@T.prim_func(private=True)
def func(A: T.Buffer((128, 128)), B: T.Buffer((128, 128))):
for i in T.thread_binding(T.int64(128), "threadIdx.x"):
for j in T.thread_binding(128, "threadIdx.y"):
B[i, j] = A[i, j]

loop_i = func.body
loop_j = loop_i.body
assert loop_i.loop_var.dtype == "int64"
assert loop_i.thread_binding.var.dtype == "int64"
assert loop_j.loop_var.dtype == "int32"
assert loop_j.thread_binding.var.dtype == "int32"


if __name__ == "__main__":
tvm.testing.main()