Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/meta_schedule/postproc/rewrite_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class BufferReadPosCollector : public StmtExprVisitor {
}

void VisitExpr_(const BufferLoadNode* op) final {
CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block";

const Buffer& buffer = op->buffer;
if (buffers_.count(buffer.get())) {
Map<Var, PrimExpr> subst_map;
Expand Down
156 changes: 108 additions & 48 deletions tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,54 +38,114 @@ def _create_context(mod, target) -> TuneContext:
)


# Input workload: a 16x16x16 matmul whose i and k loops are each split
# into outer*4 + inner.  Buffer index 1 (B) is declared layout-free, so
# the RewriteLayout postproc is allowed to re-order its storage.
@T.prim_func
def tir_matmul(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    # "layout_free_buffers" holds indices into the parameter list; 1 == B.
    T.func_attr({"layout_free_buffers": [1]})
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            # Recombine the split loop variables into the logical axes.
            vi = T.axis.S(16, i0 * 4 + i1)
            vj = T.axis.S(16, j)
            vk = T.axis.R(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


# Expected result of RewriteLayout on `tir_matmul`: B is materialized
# into a temporary buffer B_reindex laid out as [vj, vk // 4, vk % 4],
# matching the matmul block's access pattern, and the copy block is
# tagged "meta_schedule.layout_rewrite_preproc".
@T.prim_func
def rewritten_tir_matmul(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    T.func_attr({"layout_free_buffers": [1]})
    # Temporary buffer holding B in the rewritten [16, 4, 4] layout.
    B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32")
    for ax0, ax1 in T.grid(16, 16):
        with T.block("layout_rewrite"):
            i0, i1 = T.axis.remap("SS", [ax0, ax1])
            # Marks this block as a layout-rewrite preprocessing step.
            T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
            B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1]
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.spatial(16, i0 * 4 + i1)
            vj = T.axis.spatial(16, j)
            vk = T.axis.reduce(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            # B's reads go through the reindexed buffer instead of B itself.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B_reindex[vj, vk // 4, vk % 4]


def test_layout_rewrite():
    """RewriteLayout on `tir_matmul` must produce `rewritten_tir_matmul`."""
    schedule = tvm.tir.Schedule(tir_matmul, debug_mask="all")
    tune_ctx = _create_context(tir_matmul, _target())
    schedule.enter_postproc()
    # The postproc must report success before we compare results.
    assert tune_ctx.postprocs[0].apply(schedule)
    tvm.ir.assert_structural_equal(schedule.mod["main"], rewritten_tir_matmul)
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
    """Shared harness for RewriteLayout postproc tests.

    `transform` returns a callable that applies the RewriteLayout
    post-processor to a module and returns the rewritten module.
    """

    def transform(self):
        def apply_rewrite_layout(mod):
            tune_context = TuneContext(
                mod=mod,
                target=Target("cuda", host="llvm"),
                postprocs=[RewriteLayout()],
                task_name="test",
            )
            schedule = tvm.tir.Schedule(mod, debug_mask="all")
            schedule.enter_postproc()
            # The postproc must succeed for the comparison to be meaningful.
            assert tune_context.postprocs[0].apply(schedule)
            return schedule.mod

        return apply_rewrite_layout


class TestTIRMatmul(BaseBeforeAfter):
    """Main functionality test

    A new block should be inserted to transform the layout, with the
    compute block operating on the temporary transformed buffer.
    """

    # Split-loop 16x16x16 matmul; buffer index 1 (B) is layout-free.
    def before(
        A: T.Buffer[(16, 16), "float32"],
        B: T.Buffer[(16, 16), "float32"],
        C: T.Buffer[(16, 16), "float32"],
    ) -> None:
        T.func_attr({"layout_free_buffers": [1]})
        for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
            with T.block("matmul"):
                vi = T.axis.S(16, i0 * 4 + i1)
                vj = T.axis.S(16, j)
                vk = T.axis.R(16, k0 * 4 + k1)
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    # B is copied into B_reindex, laid out to match the matmul block's
    # access pattern B_reindex[vj, vk // 4, vk % 4]; the copy block is
    # tagged "meta_schedule.layout_rewrite_preproc".
    def expected(
        A: T.Buffer[(16, 16), "float32"],
        B: T.Buffer[(16, 16), "float32"],
        C: T.Buffer[(16, 16), "float32"],
    ) -> None:
        T.func_attr({"layout_free_buffers": [1]})
        B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32")
        for ax0, ax1 in T.grid(16, 16):
            with T.block("layout_rewrite"):
                i0, i1 = T.axis.remap("SS", [ax0, ax1])
                T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
                B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1]
        for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
            with T.block("matmul"):
                vi = T.axis.spatial(16, i0 * 4 + i1)
                vj = T.axis.spatial(16, j)
                vk = T.axis.reduce(16, k0 * 4 + k1)
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B_reindex[vj, vk // 4, vk % 4]


class TestRewrittenBuffersMustOccurWithinBlock(BaseBeforeAfter):
    """Buffers must occur within a Block"""

    # The layout-free buffer A is read directly inside the loop, outside
    # any T.block, so RewriteLayout cannot collect its access pattern.
    def before(
        A: T.Buffer[(16, 16), "float32"],
    ) -> None:
        T.func_attr({"layout_free_buffers": [0]})
        for i, j in T.grid(16, 16):
            T.evaluate(A[i, j])

    # The transform is expected to raise rather than return a module
    # (presumably CompareBeforeAfter treats an exception class here as
    # "must raise" — verify against the harness).
    expected = tvm.TVMError


class TestExtentOne(BaseBeforeAfter):
    """Buffers with dimensions of extent 1 can be transformed

    Regression test for a previous bug, in which the removal of
    trivial variables resulted in an error in `IndexMap::Inverse`.
    """

    # A has a trailing axis of extent 1; buffer index 0 (A) is layout-free.
    def before(
        A: T.Buffer[(16, 1), "float32"],
    ) -> None:
        T.func_attr({"layout_free_buffers": [0]})
        for i, j in T.grid(16, 1):
            with T.block("block"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.evaluate(A[vi, vj])

    # The extent-1 axis is dropped from the rewritten layout: A is copied
    # into a 1-D A_global buffer, and the consumer indexes it with vi alone.
    def expected(A: T.Buffer[(16, 1), "float32"]):
        T.func_attr({"layout_free_buffers": [0]})

        A_global = T.alloc_buffer([16], dtype="float32")
        for ax0, ax1 in T.grid(16, 1):
            with T.block("A_global"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
                A_global[v0] = A[v0, v1]

        for i, j in T.grid(16, 1):
            with T.block("block"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.evaluate(A_global[vi])


if __name__ == "__main__":
    # tvm.testing.main() discovers and runs every test in this file;
    # calling test_layout_rewrite() directly beforehand would run that
    # test twice, so the single entry point is used.
    tvm.testing.main()