From 841b2a9ef6b937f4785e6f5486caeefc67c953a5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 30 Aug 2022 09:56:03 -0500 Subject: [PATCH 01/11] [TIR] Moved tir.FlattenBuffer to occur before tir.LowerOpaqueBlock For buffers with more than one physical axis, the `axis_separators` are required in order to know which groups of logical axes to fuse into each physical axis. The implementation in `tir.FlattenBuffer` assumed that all buffers were being flattened to a single physical axis. Because `tir.LowerOpaqueBlock` replaces the `BlockNode::alloc_buffers` with `Allocate` nodes, `tir.FlattenBuffer` no longer has access to the axis separators and performs inconsistent flattening for `Allocate` as opposed to `BufferLoad`/`BufferStore`. This was introduced in https://github.com/apache/tvm/pull/12172, which decoupled the lowering/flattening steps. The commit reorders the `tir.FlattenBuffer` to occur before `tir.LowerOpaqueBlock`, to make use of the axis separators. Any `Allocate` nodes that exist at that point (e.g. from hand-written schedules) are still flattened to 1-d physical buffers, but the `BlockNode::alloc_buffers` are flattened according to the axis separators. --- src/driver/driver_api.cc | 2 +- src/tir/transforms/flatten_buffer.cc | 54 +++ .../test_tir_transform_flatten_buffer.py | 406 +++++++++--------- 3 files changed, 262 insertions(+), 200 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e528686d967d..ac02f151965b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -203,8 +203,8 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 22aef136bcff..a14331ccdc64 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -53,6 +53,34 @@ class BufferFlattener : public StmtExprMutator { } } + Stmt VisitStmt_(const BlockNode* op) final { + ICHECK_EQ(op->match_buffers.size(), 0) + << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. 
" + << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer."; + + Block block = GetRef(op); + + Array alloc_buffers = op->alloc_buffers; + alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } + + Array reads = op->reads; + reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); + if (!reads.same_as(op->reads)) { + block.CopyOnWrite()->reads = reads; + } + + Array writes = op->writes; + writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); + if (!writes.same_as(op->writes)) { + block.CopyOnWrite()->writes = writes; + } + + return StmtExprMutator::VisitStmt_(block.get()); + } + Stmt VisitStmt_(const AllocateNode* op) final { Allocate alloc = Downcast(StmtExprMutator::VisitStmt_(op)); // TODO(Lunderberg): Move the handling of boolean into a @@ -141,6 +169,32 @@ class BufferFlattener : public StmtExprMutator { return node; } + BufferRegion MutateBufferRegion(BufferRegion region) { + Buffer orig_buf = region->buffer; + Buffer flattened_buf = GetFlattenedBuffer(orig_buf); + if (flattened_buf.same_as(orig_buf)) { + return region; + } + + Array min_values; + Array max_values; + for (const auto& range : region->region) { + min_values.push_back(range->min); + max_values.push_back(range->min + range->extent - 1); + } + + Array flattened_min = orig_buf->ElemOffset(min_values); + Array flattened_max = orig_buf->ElemOffset(max_values); + + Array flattened_ranges; + ICHECK_EQ(flattened_min.size(), flattened_max.size()); + for (size_t i = 0; i < flattened_min.size(); i++) { + flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1)); + } + + return BufferRegion(flattened_buf, flattened_ranges); + } + /*! \brief Map of buffers being remapped. 
*/ std::unordered_map buffer_remap_; diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index a1195a9d2a65..1691949f30b6 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -20,211 +20,219 @@ from tvm.script import tir as T -def _check(original, transformed): - func = original - mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.FlattenBuffer()(mod) - mod = tvm.tir.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed, True) - - -@T.prim_func -def elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - for i in T.serial(0, 16): - B_new = T.allocate([1, 16], "float32", "global") - for j in T.serial(0, 16): - B_new[0, j] = A[i, j] + 1.0 - for j in T.serial(0, 16): - C[i, j] = B_new[0, j] * 2.0 - - -@T.prim_func -def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, 256, "float32") - C = T.match_buffer(c, 256, "float32") - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) - for i in T.serial(0, 16): - B_new = T.allocate([16], "float32", "global") - for j in T.serial(0, 16): - B_new[j] = A[((i * 16) + j)] + 1.0 - for j in T.serial(0, 16): - C[((i * 16) + j)] = B_new[j] * 2.0 - - -@T.prim_func -def gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - - i0 = T.env_thread("blockIdx.x") - i1 = T.env_thread("threadIdx.x") - i2 = T.env_thread("vthread") - - T.launch_thread(i0, 4) - T.launch_thread(i1, 2) - T.launch_thread(i2, 2) - B = T.allocate([1, 16], "float32", "local") - for j in range(0, 16): - B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 - for j in range(0, 16): - C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 - - -@T.prim_func -def flattened_gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, 256, "float32") - C = T.match_buffer(c, 256, "float32") - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) - - i0 = T.env_thread("blockIdx.x") - i1 = T.env_thread("threadIdx.x") - i2 = T.env_thread("vthread") - - T.launch_thread(i0, 4) - T.launch_thread(i1, 2) - T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") - for j in range(0, 16): - B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 - for j in range(0, 16): - C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 - - -@T.prim_func -def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, (n, m), "float32") - C = T.match_buffer(c, (n, m), "float32") - - for i in range(0, n): - B = T.allocate([m], "float32", "global") - for j in range(0, m): - B[j] = A[i, j] + 1.0 - for j in range(0, m): - C[i, j] = B[j] * 2.0 - - -@T.prim_func -def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, n * m, "float32") - C = T.match_buffer(c, n * m, "float32") - T.preflattened_buffer(A, (n, m), "float32", data=A.data) - T.preflattened_buffer(C, (n, m), "float32", data=C.data) - - for i in range(0, n): - B = T.allocate([m], "float32", "global") - for j in range(0, m): - B[j] = A[i * m + j] + 1.0 - for j in range(0, m): - C[i * m + j] = B[j] * 2.0 - - -@T.prim_func -def 
multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, (4, 32), "float32") - D = T.match_buffer(d, (4, 32), "float32") - - for i, j in T.grid(4, 32): - B = T.allocate((4, 32), "float32", scope="global") - C = T.allocate((4, 32), "float32", scope="global") - B[i, j] = A[i, j] + 1.0 - C[i, j] = A[i, j] + B[i, j] - D[i, j] = C[i, j] * 2.0 - - -@T.prim_func -def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, 128, "float32") - D = T.match_buffer(d, 128, "float32") - T.preflattened_buffer(A, (4, 32), "float32", data=A.data) - T.preflattened_buffer(D, (4, 32), "float32", data=D.data) - - for i, j in T.grid(4, 32): - B = T.allocate([128], "float32", "global") - C = T.allocate([128], "float32", "global") - B[i * 32 + j] = A[i * 32 + j] + 1.0 - C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] - D[i * 32 + j] = C[i * 32 + j] * 2.0 - - -@T.prim_func -def strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - for i0 in T.serial(4): - B = T.allocate([4, 17], "float32", "global") - B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) - for i1, j in T.grid(4, 16): - B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 - for i1, j in T.grid(4, 16): - C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 - - -@T.prim_func -def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (256,), "float32") - C = T.match_buffer(c, (256,), "float32") - T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) - T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) - for i0 in T.serial(0, 4): - B_new = T.allocate([68], "float32", "global") - for i1 in T.serial(0, 4): - for j in T.serial(0, 16): - B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 - for i1 in T.serial(0, 4): - for j in T.serial(0, 16): - C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 - - -@T.prim_func -def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None: - for i0 in T.serial(10): - b[i0] = a[i0] - - -@T.prim_func -def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None: - T.preflattened_buffer(a, [10], dtype="bool", data=a.data) - T.preflattened_buffer(b, [10], dtype="bool", data=b.data) - # body - for i0 in T.serial(10): - b[i0] = T.cast(T.cast(a[i0], "bool"), "int8") - - -def test_elementwise(): - _check(elementwise_func, flattened_elementwise_func) - +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.transform.Sequential( + [ + tvm.tir.transform.FlattenBuffer(), + tvm.tir.transform.Simplify(), + ] + ) -def test_gpu_workload(): - _check(gpu_func, flattened_gpu_func) +class TestElementwise(BaseCompare): + """2-d buffers are flattened to 1-d""" -def test_symbolic_shape(): - _check(symbolic_func, flattened_symbolic_func) - - -def test_multi_alloc(): - _check(multi_alloc_func, flattened_multi_alloc_func) - - -def test_strided_buffer(): - _check(strided_buffer_func, flattened_strided_buffer_func) - + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i in T.serial(0, 16): + B_new = T.alloc_buffer([1, 16], "float32") + for j in T.serial(0, 16): + B_new[0, j] = A[i, j] + 1.0 + for j in T.serial(0, 16): + C[i, j] = B_new[0, j] * 2.0 -def test_lower_te(): - x = te.placeholder((1,)) - y = te.compute((1,), lambda i: x[i] + 2) - s = te.create_schedule(y.op) - orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) - mod = 
tvm.tir.transform.FlattenBuffer()(orig_mod) - tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + for i in T.serial(0, 16): + B_new = T.alloc_buffer([16], "float32") + for j in T.serial(0, 16): + B_new[j] = A[((i * 16) + j)] + 1.0 + for j in T.serial(0, 16): + C[((i * 16) + j)] = B_new[j] * 2.0 + + +class TestGPU(BaseCompare): + """Buffers allocated inside GPU-specific constructs are ignored. + + These are assumed to be deliberate on the part of the + schedule-writer, and are left as-is. + """ + + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.alloc_buffer([1, 16], "float32", scope="local") + for j in range(0, 16): + B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 + for j in range(0, 16): + C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.alloc_buffer([16], "float32", scope="local") + for j in range(0, 16): + B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 + for j in range(0, 16): + C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 + + +class TestSymbolic(BaseCompare): + """Dynamically-sized arrrays are flattened""" + + def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n, m), "float32") + C = T.match_buffer(c, (n, m), "float32") + + for i in range(0, n): + B = T.alloc_buffer([m], "float32") + for j in range(0, m): + B[j] = A[i, j] + 1.0 + for j in range(0, m): + C[i, j] = B[j] * 2.0 + + def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, n * m, "float32") + C = T.match_buffer(c, n * m, "float32") + T.preflattened_buffer(A, (n, m), "float32", data=A.data) + T.preflattened_buffer(C, (n, m), "float32", data=C.data) + + for i in range(0, n): + B = T.alloc_buffer([m], "float32") + for j in range(0, m): + B[j] = A[i * m + j] + 1.0 + for j in range(0, m): + C[i * m + j] = B[j] * 2.0 + + +class TestMultiAlloc(BaseCompare): + """If multiple allocations occur, all are flattened.""" + + def before(A: T.Buffer[(4, 32), "float32"], D: T.Buffer[(4, 32), "float32"]): + for i, j in T.grid(4, 32): + B = T.alloc_buffer((4, 32), "float32", scope="global") + C = T.alloc_buffer((4, 32), "float32", scope="global") + B[i, j] = A[i, j] + 1.0 + C[i, j] = A[i, j] + B[i, j] + D[i, j] = C[i, j] * 2.0 + + def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]): + T.preflattened_buffer(A, (4, 32), "float32", data=A.data) + T.preflattened_buffer(D, (4, 32), "float32", data=D.data) + + for i, j in T.grid(4, 32): + B = T.alloc_buffer([128], "float32") + C = T.alloc_buffer([128], "float32") + B[i * 32 + j] = A[i * 32 + j] + 1.0 + C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] + D[i * 32 + j] = C[i * 32 + j] * 2.0 + + +class TestStrided(BaseCompare): + """Indices for flattened buffers use 
the specified striding.""" + + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i0 in T.serial(4): + B = T.alloc_buffer([4, 17], "float32") + B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) + for i1, j in T.grid(4, 16): + B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 + for i1, j in T.grid(4, 16): + C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) + T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) + for i0 in T.serial(0, 4): + B_new = T.alloc_buffer([68], "float32") + for i1 in T.serial(0, 4): + for j in T.serial(0, 16): + B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 + for i1 in T.serial(0, 4): + for j in T.serial(0, 16): + C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 + + +class TestBoolean(BaseCompare): + """Boolean buffers should be replaced by a backing int8 array""" + + def before(A: T.Buffer[10, "bool"], B: T.Buffer[10, "bool"]) -> None: + for i0 in T.serial(10): + B[i0] = A[i0] + + def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, "int8"]) -> None: + T.preflattened_buffer(A, [10], dtype="bool", data=A.data) + T.preflattened_buffer(B, [10], dtype="bool", data=B.data) + for i0 in T.serial(10): + B[i0] = T.cast(T.cast(A[i0], "bool"), "int8") + + +class TestLowerTE(BaseCompare): + """FlattenBuffer should do nothing on TE-based functions""" + + def before(self): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + return mod["main"] + + expected = before + + +class TestFlattenInsideBlock(BaseCompare): + """Flattening access inside a block flattens the accessed region.""" + + def before(): + A = T.alloc_buffer([32, 32]) + for i, j in T.grid(32, 32): + with T.block("block"): + T.evaluate(A[i, j]) + + def expected(): + A = T.alloc_buffer([1024]) + for i, j in T.grid(32, 32): + with T.block("block"): + T.evaluate(A[i * 32 + j]) + + +class TestNoChangeTo2DPhysicalBuffer(BaseCompare): + """Flattening preserves axis separators.""" + + def before(): + A = T.alloc_buffer([32, 32], axis_separators=[1]) + for i, j in T.grid(32, 32): + T.evaluate(A[i, j]) + expected = before + + +class TestFlattenWithAxisSeparators(BaseCompare): + """Flattening preserves axis separators""" + + def before(): + A = T.alloc_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0, i1, i2, i3, i4, i5]) -def test_boolean_handling(): - _check(boolean_handling_before, boolean_handling_after) + def expected(): + A = T.alloc_buffer([30, 1001], axis_separators=[1]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) if __name__ == "__main__": From b123adb1202e88d13e9b0e49ef6aed6841ff1257 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Aug 2022 09:51:39 -0500 Subject: [PATCH 02/11] Add unit test to validate non-flat memory after tvm.lower --- .../test_tir_transform_flatten_buffer.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 1691949f30b6..b2d0f8f093b7 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -235,5 
+235,36 @@ def expected(): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) +def test_lower_2d_physical_memory(): + """Axis separators should preserve 2-d buffers through lowering. + + A catch-all test to ensure that defining axis_separators is + sufficient to maintain non-flat buffer descriptions through all + lowering steps. + """ + + # This test doesn't use CompareBeforeAfter, because the after step + # is not currently expressible in TVMScript. This test can be + # re-written after https://github.com/apache/tvm/pull/12412. + + @T.prim_func + def func(): + buf = T.alloc_buffer( + [1, 1], + dtype="int32", + scope="global", + axis_separators=[1], + ) + buf[0, 0] = 0 + + lowered = tvm.lower(func)["main"] + assert isinstance(lowered.body, tvm.tir.Allocate) + assert list(lowered.body.extents) == [1, 1], ( + "Non-flat buffer allocations, " + "marked by axis_separators, " + "flattened to flat memory allocation." + ) + + if __name__ == "__main__": tvm.testing.main() From 6bc203cce637dc744d5741d3b544753dc5c13c7b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Aug 2022 10:05:13 -0500 Subject: [PATCH 03/11] Explicitly write T.reads for test on BufferRegion updates --- tests/python/unittest/test_tir_transform_flatten_buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index b2d0f8f093b7..bd6ad1d680ae 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -201,12 +201,14 @@ def before(): A = T.alloc_buffer([32, 32]) for i, j in T.grid(32, 32): with T.block("block"): + T.reads(A[i, j]) T.evaluate(A[i, j]) def expected(): A = T.alloc_buffer([1024]) for i, j in T.grid(32, 32): with T.block("block"): + T.reads(A[i * 32 + j]) T.evaluate(A[i * 32 + j]) From a4b7573dfff9d6c2135bc979db977bdb262bacbf Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Aug 2022 10:10:31 -0500 Subject: [PATCH 04/11] Update incorrect docstring for test --- tests/python/unittest/test_tir_transform_flatten_buffer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index bd6ad1d680ae..0483c1b7901a 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -52,11 +52,7 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): class TestGPU(BaseCompare): - """Buffers allocated inside GPU-specific constructs are ignored. - - These are assumed to be deliberate on the part of the - schedule-writer, and are left as-is. - """ + """Buffer flattening may have indices based on GPU thread vars""" def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): From 97fcb29a294ad61d52086bc6b2d012009d0b9093 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Aug 2022 11:03:16 -0500 Subject: [PATCH 05/11] Use DeclBuffer information in FlattenBuffer The DeclBuffer node can be inserted during LowerOpaqueBlock, then provide the missing Buffer information required to flatten the allocation. 
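As a rough illustration of what this enables (a sketch adapted from the axis-separator tests added later in this series; the shapes, loop structure, and use of T.decl_buffer here are illustrative and not part of this commit's diff):

    from tvm.script import tir as T

    @T.prim_func
    def before():
        # N-d allocation whose buffer information carries axis_separators=[3]
        A = T.decl_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3])
        for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
            T.evaluate(A[i0, i1, i2, i3, i4, i5])

    @T.prim_func
    def after():
        # The six logical axes are grouped into the two physical axes
        # requested by the axis separator: [2, 3, 5] -> 30, [7, 11, 13] -> 1001
        A = T.decl_buffer([30, 1001], axis_separators=[1])
        for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
            T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])
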
--- src/driver/driver_api.cc | 2 +- src/tir/transforms/flatten_buffer.cc | 55 ++++++++++++++++++++---- src/tir/transforms/lower_opaque_block.cc | 1 + 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index ac02f151965b..e528686d967d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -203,8 +203,8 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); + pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index a14331ccdc64..1d0123873e07 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -21,6 +21,7 @@ * \file flatten_buffer.cc */ +#include #include #include @@ -89,18 +90,56 @@ class BufferFlattener : public StmtExprMutator { auto writer = alloc.CopyOnWrite(); writer->dtype = DataType::Int(8); } - // Handle multi-dimension allocations + if (alloc->extents.size() == 1) { + // No flattening required for buffers that are already flat return std::move(alloc); - } else { - Array flat_extent(static_cast(1), 1); - for (size_t i = 0; i < alloc->extents.size(); i++) { - flat_extent.Set(0, flat_extent[0] * alloc->extents[i]); + } + + if (auto* decl_buffer = alloc->body.as(); + decl_buffer && decl_buffer->buffer->data.same_as(alloc->buffer_var)) { + // N-d buffer, use the DeclBuffer inside to determine how it + // should be flattened. 
+ auto& buffer = decl_buffer->buffer; + bool matching_buffer = [&]() { + if (alloc->dtype != buffer->dtype) { + return false; + } + if (alloc->extents.size() != buffer->shape.size()) { + return false; + } + ExprDeepEqual expr_equal; + for (size_t i = 0; i < alloc->extents.size(); i++) { + if (!expr_equal(alloc->extents[i], buffer->shape[i])) { + return false; + } + } + return true; + }(); + + if (matching_buffer) { + Buffer flattened = GetFlattenedBuffer(buffer); + + auto n = alloc.CopyOnWrite(); + n->body = DeclBuffer(flattened, std::move(decl_buffer->body)); + n->extents = flattened->shape; + return std::move(alloc); + } else { + ICHECK(decl_buffer->buffer->axis_separators.empty()) + << "DeclBuffer node doesn't match Allocate extents, but also shouldn't be " + "flattened to 1-d physical memory"; } - auto n = alloc.CopyOnWrite(); - n->extents = flat_extent; - return std::move(alloc); } + + // Fallback, this is an allocation without a matching DeclBuffer + PrimExpr flat_extent = 1; + for (const auto& dim : alloc->extents) { + flat_extent *= dim; + } + + auto n = alloc.CopyOnWrite(); + n->extents = {flat_extent}; + return std::move(alloc); } Buffer GetFlattenedBuffer(Buffer buf) { diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 69d8787aa1a1..05af9962646e 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -57,6 +57,7 @@ class OpaqueBlockLower : public StmtExprMutator { new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]); } } + body = DeclBuffer(buffer, std::move(body)); body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body)); } return body; From 2bf7bf280766170f262b8b10326cc4545b4ff8d1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Aug 2022 11:27:33 -0500 Subject: [PATCH 06/11] Use T.allocate in unit tests With the insertion of `DeclBuffer` nodes, `LowerOpaqueBlock` no longer needs to be before `FlattenBuffer`, and has been moved back to its original position. Revering the tests to use `T.allocate` instead of `T.alloc_buffer` more closely represents the functions as they are being lowered. 
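For illustration, the two spellings being contrasted (a sketch only; the function names and simplified loop bodies are hypothetical, while the allocation forms mirror the tests touched below):

    from tvm.script import tir as T

    @T.prim_func
    def scheduled_form(A: T.Buffer[(16, 16), "float32"]):
        # Scheduling-level spelling: the buffer lives in BlockNode::alloc_buffers
        B = T.alloc_buffer([1, 16], "float32")
        for j in T.serial(0, 16):
            B[0, j] = A[0, j] + 1.0

    @T.prim_func
    def lowered_form(A: T.Buffer[(16, 16), "float32"]):
        # Lowered spelling exercised by these tests: a raw T.allocate plus a
        # T.buffer_decl view over the same backing storage
        B_data = T.allocate([1, 16], "float32", "global")
        B = T.buffer_decl([1, 16], "float32", data=B_data)
        for j in T.serial(0, 16):
            B[0, j] = A[0, j] + 1.0
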
--- .../test_tir_transform_flatten_buffer.py | 78 +++++++++++++++---- 1 file changed, 64 insertions(+), 14 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 0483c1b7901a..5356c7001e36 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -34,7 +34,7 @@ class TestElementwise(BaseCompare): def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for i in T.serial(0, 16): - B_new = T.alloc_buffer([1, 16], "float32") + B_new = T.decl_buffer([1, 16], "float32", "global") for j in T.serial(0, 16): B_new[0, j] = A[i, j] + 1.0 for j in T.serial(0, 16): @@ -44,7 +44,38 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): - B_new = T.alloc_buffer([16], "float32") + B_new = T.decl_buffer(16, "float32", "global") + for j in T.serial(0, 16): + B_new[j] = A[((i * 16) + j)] + 1.0 + for j in T.serial(0, 16): + C[((i * 16) + j)] = B_new[j] * 2.0 + + +class TestElementwiseWithoutDeclBuffer(BaseCompare): + """2-d buffers are flattened to 1-d + + Like TestElementwise, but the TIR doesn't have the DeclBuffer + node. The T.buffer_decl declaration applies only during the + parsing the TVMScript, and doesn't occur in the TIR itself. In + this case, the allocation should be assumed to be targeting flat + memory, and should be flattened to a 1-d allocation. + """ + + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i in T.serial(0, 16): + B_new_data = T.allocate([1, 16], "float32", "global") + B_new = T.buffer_decl([1, 16], "float32", data=B_new_data) + for j in T.serial(0, 16): + B_new[0, j] = A[i, j] + 1.0 + for j in T.serial(0, 16): + C[i, j] = B_new[0, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + for i in T.serial(0, 16): + B_new_data = T.allocate(16, "float32", "global") + B_new = T.buffer_decl(16, "float32", data=B_new_data) for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): @@ -55,7 +86,6 @@ class TestGPU(BaseCompare): """Buffer flattening may have indices based on GPU thread vars""" def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): - i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") i2 = T.env_thread("vthread") @@ -63,7 +93,7 @@ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.alloc_buffer([1, 16], "float32", scope="local") + B = T.decl_buffer([1, 16], "float32", "local") for j in range(0, 16): B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 for j in range(0, 16): @@ -80,7 +110,7 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.alloc_buffer([16], "float32", scope="local") + B = T.decl_buffer([16], "float32", "local") for j in range(0, 16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): @@ -95,7 +125,7 @@ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: C = T.match_buffer(c, (n, m), 
"float32") for i in range(0, n): - B = T.alloc_buffer([m], "float32") + B = T.decl_buffer([m], "float32", "global") for j in range(0, m): B[j] = A[i, j] + 1.0 for j in range(0, m): @@ -108,7 +138,7 @@ def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: T.preflattened_buffer(C, (n, m), "float32", data=C.data) for i in range(0, n): - B = T.alloc_buffer([m], "float32") + B = T.decl_buffer([m], "float32", "global") for j in range(0, m): B[j] = A[i * m + j] + 1.0 for j in range(0, m): @@ -120,8 +150,8 @@ class TestMultiAlloc(BaseCompare): def before(A: T.Buffer[(4, 32), "float32"], D: T.Buffer[(4, 32), "float32"]): for i, j in T.grid(4, 32): - B = T.alloc_buffer((4, 32), "float32", scope="global") - C = T.alloc_buffer((4, 32), "float32", scope="global") + B = T.decl_buffer((4, 32), "float32", scope="global") + C = T.decl_buffer((4, 32), "float32", scope="global") B[i, j] = A[i, j] + 1.0 C[i, j] = A[i, j] + B[i, j] D[i, j] = C[i, j] * 2.0 @@ -131,8 +161,8 @@ def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]): T.preflattened_buffer(D, (4, 32), "float32", data=D.data) for i, j in T.grid(4, 32): - B = T.alloc_buffer([128], "float32") - C = T.alloc_buffer([128], "float32") + B = T.decl_buffer([128], "float32", "global") + C = T.decl_buffer([128], "float32", "global") B[i * 32 + j] = A[i * 32 + j] + 1.0 C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] D[i * 32 + j] = C[i * 32 + j] * 2.0 @@ -143,7 +173,7 @@ class TestStrided(BaseCompare): def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for i0 in T.serial(4): - B = T.alloc_buffer([4, 17], "float32") + B = T.decl_buffer([4, 17], "float32", "global") B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) for i1, j in T.grid(4, 16): B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 @@ -154,7 +184,7 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) for i0 in T.serial(0, 4): - B_new = T.alloc_buffer([68], "float32") + B_new = T.decl_buffer([68], "float32", "global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 @@ -173,6 +203,7 @@ def before(A: T.Buffer[10, "bool"], B: T.Buffer[10, "bool"]) -> None: def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, "int8"]) -> None: T.preflattened_buffer(A, [10], dtype="bool", data=A.data) T.preflattened_buffer(B, [10], dtype="bool", data=B.data) + # body for i0 in T.serial(10): B[i0] = T.cast(T.cast(A[i0], "bool"), "int8") @@ -219,7 +250,7 @@ def before(): expected = before -class TestFlattenWithAxisSeparators(BaseCompare): +class TestFlattenAllocBufferWithAxisSeparators(BaseCompare): """Flattening preserves axis separators""" def before(): @@ -233,6 +264,25 @@ def expected(): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) +class TestFlattenDeclBufferWithAxisSeparators(BaseCompare): + """Flattening preserves axis separators + + Like TestFlattenAllocBufferWithAxisSeparators, but the allocations + is done using Allocate/DeclBuffer, rather than through + BlockNode::alloc_buffers. 
+ """ + + def before(): + A = T.decl_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0, i1, i2, i3, i4, i5]) + + def expected(): + A = T.decl_buffer([30, 1001], axis_separators=[1]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) + + def test_lower_2d_physical_memory(): """Axis separators should preserve 2-d buffers through lowering. From aaa47b825fde6a9fe348eaf6e35fe7b2456fca6c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Aug 2022 11:56:12 -0500 Subject: [PATCH 07/11] Fix usage of T.decl_buffer in updated tests --- .../test_tir_transform_flatten_buffer.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 5356c7001e36..043ee99eaa38 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -34,7 +34,7 @@ class TestElementwise(BaseCompare): def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for i in T.serial(0, 16): - B_new = T.decl_buffer([1, 16], "float32", "global") + B_new = T.decl_buffer([1, 16], "float32") for j in T.serial(0, 16): B_new[0, j] = A[i, j] + 1.0 for j in T.serial(0, 16): @@ -44,7 +44,7 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): - B_new = T.decl_buffer(16, "float32", "global") + B_new = T.decl_buffer([16], "float32") for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): @@ -74,7 +74,7 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): - B_new_data = T.allocate(16, "float32", "global") + B_new_data = T.allocate([16], "float32", "global") B_new = T.buffer_decl(16, "float32", data=B_new_data) for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 @@ -93,7 +93,7 @@ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.decl_buffer([1, 16], "float32", "local") + B = T.decl_buffer([1, 16], "float32", scope="local") for j in range(0, 16): B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 for j in range(0, 16): @@ -110,7 +110,7 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.decl_buffer([16], "float32", "local") + B = T.decl_buffer([16], "float32", scope="local") for j in range(0, 16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): @@ -125,7 +125,7 @@ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - B = T.decl_buffer([m], "float32", "global") + B = T.decl_buffer([m], "float32") for j in range(0, m): B[j] = A[i, j] + 1.0 for j in range(0, m): @@ -138,7 +138,7 @@ def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: T.preflattened_buffer(C, (n, m), "float32", data=C.data) for i in range(0, n): - B = T.decl_buffer([m], 
"float32", "global") + B = T.decl_buffer([m], "float32") for j in range(0, m): B[j] = A[i * m + j] + 1.0 for j in range(0, m): @@ -161,8 +161,8 @@ def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]): T.preflattened_buffer(D, (4, 32), "float32", data=D.data) for i, j in T.grid(4, 32): - B = T.decl_buffer([128], "float32", "global") - C = T.decl_buffer([128], "float32", "global") + B = T.decl_buffer([128], "float32") + C = T.decl_buffer([128], "float32") B[i * 32 + j] = A[i * 32 + j] + 1.0 C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] D[i * 32 + j] = C[i * 32 + j] * 2.0 @@ -173,7 +173,7 @@ class TestStrided(BaseCompare): def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for i0 in T.serial(4): - B = T.decl_buffer([4, 17], "float32", "global") + B = T.decl_buffer([4, 17], "float32") B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) for i1, j in T.grid(4, 16): B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 @@ -184,7 +184,7 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) for i0 in T.serial(0, 4): - B_new = T.decl_buffer([68], "float32", "global") + B_new = T.decl_buffer([68], "float32") for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 From ea0a87ada70e640102cd2c8864ee0877787f60f7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 1 Sep 2022 10:21:45 -0500 Subject: [PATCH 08/11] Update LowerOpaqueBuffer to expect the DeclBuffer nodes --- .../test_tir_transform_lower_opaque_block.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py index f8f3e3a5aced..824cef174055 100644 --- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py +++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py @@ -54,8 +54,7 @@ def transformed_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in T.serial(0, 16): - B_new_data = T.allocate([1, 16], "float32", "global") - B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data) + B_new = T.decl_buffer(shape=[1, 16], dtype="float32") for j in T.serial(0, 16): B_new[0, j] = A[i, j] + 1.0 for j in T.serial(0, 16): @@ -97,8 +96,7 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B_data = T.allocate([1, 16], "float32", "local") - B = T.buffer_decl(shape=[1, 16], dtype="float32", scope="local", data=B_data) + B = T.decl_buffer(shape=[1, 16], dtype="float32", scope="local") for j in range(0, 16): B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 for j in range(0, 16): @@ -133,8 +131,7 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - B_data = T.allocate([m], "float32", "global") - B = T.buffer_decl(shape=[m], dtype="float32", data=B_data) + B = T.decl_buffer(shape=[m], dtype="float32") for j in range(0, m): B[j] = A[i, j] + 1.0 for j in range(0, m): @@ -207,10 +204,8 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - B_data = T.allocate((32,), "float32", 
"global") - B = T.buffer_decl(shape=(32,), dtype="float32", data=B_data) - C_data = T.allocate((32,), "float32", "global") - C = T.buffer_decl(shape=(32,), dtype="float32", data=C_data) + B = T.decl_buffer(shape=(32,), dtype="float32") + C = T.decl_buffer(shape=(32,), dtype="float32") B[i] = A[i] + 1.0 C[i] = A[i] + B[i] D[i] = C[i] * 2.0 @@ -246,12 +241,11 @@ def transformed_strided_buffer_func( # body for i0 in T.serial(4): B_data = T.allocate([4, 17], "float32", "global") - B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data) - B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) + B = T.decl_buffer(shape=[4, 16], dtype="float32", strides=[17, 1], data=B_data) for i1, j in T.grid(4, 16): - B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1) + B[i1, j] = A[i0 * 4 + i1, j] + T.float32(1) for i1, j in T.grid(4, 16): - C[i0 * 4 + i1, j] = B_1[i1, j] * T.float32(2) + C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2) @T.prim_func From 12bba285db0ed4f52fc1c8e6ee4b9eb3a92b4905 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 1 Sep 2022 10:55:04 -0500 Subject: [PATCH 09/11] Strip DeclBuffer annotation in FlattenBuffer The DeclBuffer annotations aren't yet supported in all passes. This restricts them to being introduced in LowerOpaqueBuffer, then immediately removed in FlattenBuffer. --- src/tir/transforms/flatten_buffer.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 1d0123873e07..92afa0c15fa8 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -121,7 +121,13 @@ class BufferFlattener : public StmtExprMutator { Buffer flattened = GetFlattenedBuffer(buffer); auto n = alloc.CopyOnWrite(); - n->body = DeclBuffer(flattened, std::move(decl_buffer->body)); + // TODO(rfc-70): Update the DeclBuffer node instead of + // stripping it out. Stripping it out in the current + // implementation as not all lowering passes support + // DeclBuffer. + // + // n->body = DeclBuffer(flattened, std::move(decl_buffer->body)); + n->body = std::move(decl_buffer->body); n->extents = flattened->shape; return std::move(alloc); } else { From 2a35839bbd62589bea229fabe29746c56e8a0236 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 10:25:08 -0500 Subject: [PATCH 10/11] Strip out all DeclBuffer nodes in FlattenBuffer --- src/tir/transforms/flatten_buffer.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 92afa0c15fa8..5441120491c6 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -93,6 +93,14 @@ class BufferFlattener : public StmtExprMutator { if (alloc->extents.size() == 1) { // No flattening required for buffers that are already flat + + // TODO(rfc-70): Keep the DeclBuffer node as-is. Stripping it + // out in the current implementation as not all lowering passes + // support DeclBuffer. 
+ if (auto* decl_buffer = alloc->body.as()) { + alloc.CopyOnWrite()->body = std::move(decl_buffer->body); + } + return std::move(alloc); } From 018f4f8d9f86b32e750c475fa98d9a5ac0ecc6c0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 10:25:27 -0500 Subject: [PATCH 11/11] Update unit tests to remove expectation of DeclBuffer nodes --- .../test_tir_transform_flatten_buffer.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 043ee99eaa38..870208499e7a 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -44,7 +44,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): - B_new = T.decl_buffer([16], "float32") + B_new_data = T.allocate([16], "float32", scope="global") + B_new = T.buffer_decl([16], "float32", scope="global", data=B_new_data) for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): @@ -110,7 +111,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.decl_buffer([16], "float32", scope="local") + B_data = T.allocate([16], "float32", scope="local") + B = T.buffer_decl([16], "float32", scope="local", data=B_data) for j in range(0, 16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): @@ -138,7 +140,8 @@ def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: T.preflattened_buffer(C, (n, m), "float32", data=C.data) for i in range(0, n): - B = T.decl_buffer([m], "float32") + B_data = T.allocate([m], "float32", scope="global") + B = T.buffer_decl([m], "float32", scope="global", data=B_data) for j in range(0, m): B[j] = A[i * m + j] + 1.0 for j in range(0, m): @@ -161,8 +164,10 @@ def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]): T.preflattened_buffer(D, (4, 32), "float32", data=D.data) for i, j in T.grid(4, 32): - B = T.decl_buffer([128], "float32") - C = T.decl_buffer([128], "float32") + B_data = T.allocate([128], "float32", scope="global") + B = T.buffer_decl([128], "float32", scope="global", data=B_data) + C_data = T.allocate([128], "float32", scope="global") + C = T.buffer_decl([128], "float32", scope="global", data=C_data) B[i * 32 + j] = A[i * 32 + j] + 1.0 C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] D[i * 32 + j] = C[i * 32 + j] * 2.0 @@ -184,7 +189,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) for i0 in T.serial(0, 4): - B_new = T.decl_buffer([68], "float32") + B_new_data = T.allocate([68], "float32", scope="global") + B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data) for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 @@ -278,7 +284,10 @@ def before(): T.evaluate(A[i0, i1, i2, i3, i4, i5]) def expected(): - A = T.decl_buffer([30, 1001], axis_separators=[1]) + A_data = T.allocate([30, 1001], dtype="float32", scope="global") + A = T.buffer_decl( + [30, 1001], dtype="float32", scope="global", axis_separators=[1], 
data=A_data + ) for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])