Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/te/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
if (actx->CanProve(indexmod(a, b) == 0)) {
return actx->Simplify(indexdiv(a, b));
}
return actx->Simplify(indexdiv(a + (b - 1), b));
};

auto& state = *p_state;
// forwar iteration on relations
// forward iteration on relations
for (IterVarRelation rel : stage->relations) {
if (const SplitNode* r = rel.as<SplitNode>()) {
if (!state.count(r->parent)) {
Expand All @@ -73,11 +73,16 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
Update(p_state, r->inner,
Range::make_by_min_extent(0, r->factor), actx);
PrimExpr outer_extent = ceil_div(range_parent->extent, r->factor);
if (is_one(outer_extent)) {
Update(p_state, r->inner,
Range::make_by_min_extent(0, range_parent->extent), actx);
} else {
Update(p_state, r->inner,
Range::make_by_min_extent(0, r->factor), actx);
}
Update(p_state, r->outer,
Range::make_by_min_extent(
0, ceil_div(range_parent->extent, r->factor)), actx);
Range::make_by_min_extent(0, outer_extent) , actx);
} else {
Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
Update(p_state, r->inner,
Expand Down
18 changes: 0 additions & 18 deletions tests/python/unittest/test_pass_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,23 +184,6 @@ def test_condition_EQ():
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))

def test_thread_axis2():
n = tvm.convert(4096)
m = tvm.size_var('m')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.create_schedule(C.op)
num_thread = 32
bx, x = s[C].split(C.op.axis[0], factor=32)
tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(x, factor=m)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
stmt = lower(s, [A, B])
for_body = stmt.body.body.body.body.body[0]
assert('threadIdx' not in str(for_body.extent))

def test_everything_during_deduction():
m = tvm.size_var('m')
n = tvm.size_var('n')
Expand Down Expand Up @@ -455,7 +438,6 @@ def test_simple_rfactor():
test_vectorize()
test_condition()
test_condition_EQ()
test_thread_axis2()
test_everything_during_deduction()
test_single_likely()
test_multi_likely()
Expand Down
37 changes: 37 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,42 @@ def test_bound_fusesplit2():
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3)


def test_bound_split_compute_at():
N = M = K = 1024

A = tvm.placeholder((N, K), name='A')
B = tvm.placeholder((M, K), name='B')
C = tvm.compute((N, M), lambda i, j: A[i][j] * 2.0, name='C')
D = tvm.compute((N, M), lambda i, j: C[i][j] + 2.0, name='D')

s = tvm.create_schedule([D.op])

def multiple_split(X):
i, j = s[X].op.axis
i2, i3 = s[X].split(i, 4)
i1, i2 = s[X].split(i2, 4)
i0, i1 = s[X].split(i1, 4)

j2, j3 = s[X].split(j, 4)
j1, j2 = s[X].split(j2, 4)
j0, j1 = s[X].split(j1, 4)

iters = i0, j0, i1, j1, i2, j2, i3, j3
s[X].reorder(*iters)
return iters

c_iters = multiple_split(C)
d_iters = multiple_split(D)

s[C].compute_at(s[D], d_iters[-3])

bounds = tvm.schedule.InferBound(s)
for i in range(0, 6):
assert bounds[c_iters[i]].extent.value == 1
for i in range(6, 8):
assert bounds[c_iters[i]].extent.value == 4


def test_bound_warp():
m = tvm.var('m')
l = tvm.var('l')
Expand Down Expand Up @@ -422,5 +458,6 @@ def _check(B, A=A):
test_bound_simplification_failure()
test_bound_fusesplit1()
test_bound_fusesplit2()
test_bound_split_compute_at()
test_bound_split_divisible()
test_bound_tile_divisible()