From 5b60b4fe2b640389a7c17de1e2f299ab2b06f38e Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 14 Jun 2022 15:06:32 -0700 Subject: [PATCH 1/2] [Arith] Update BufferDomainTouched to support vector access. --- src/arith/domain_touched.cc | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 8874f4f16506..403ea47f4e61 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -65,11 +65,14 @@ class BufferTouchedDomain final : public StmtExprVisitor { } Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) { + Region ret; auto kv = buffer_access_map_.find(buffer.get()); - CHECK(kv != buffer_access_map_.end()) - << "The requested buffer is not contained in the provided stmt body."; + if (kv == buffer_access_map_.end()) { + LOG(WARNING) << "[arith::BufferDomainTouched] " + << "The requested buffer is not contained in the provided stmt body: " << buffer; + return ret; + } - Region ret; Range none; BufferTouches bounds; if (consider_loads && consider_stores) { @@ -131,13 +134,16 @@ class BufferTouchedDomain final : public StmtExprVisitor { } private: - template - void Touch(BufferTouches* bounds, const ArrayType& args) const { + void Touch(BufferTouches* bounds, const Array& args) const { if (args.size() > bounds->size()) { bounds->resize(args.size()); } for (size_t i = 0; i < args.size(); ++i) { - (*bounds)[i].emplace_back(EvalSet(args[i], dom_map_)); + if (args[i].as()) { + (*bounds)[i].emplace_back(IntSet::Vector(args[i])); + } else { + (*bounds)[i].emplace_back(EvalSet(args[i], dom_map_)); + } } } From 71439cc7d17c6654414236d04ad32ee41f4240c8 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 14 Jun 2022 16:29:36 -0700 Subject: [PATCH 2/2] Add test checking that domain touched works on IR containing RampNodes. --- .../unittest/test_arith_domain_touched.py | 63 +++++++++++-------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index af06a038e1f7..8b982e65a055 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -16,34 +16,36 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T + + +@T.prim_func +def scalar_func(a: T.handle, b: T.handle): + m = T.var("int32") + n = T.int32(100) + A = T.match_buffer(a, (n, m), name="A") + B = T.match_buffer(b, (n, m), name="B") + + for i, j in T.grid(n, m): + A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1] + + +@T.prim_func +def vector_func(a: T.handle, b: T.handle): + n = T.var("int32") + m = T.int32(128) + A = T.match_buffer(a, (n, m), name="A") + B = T.match_buffer(b, (n, m), name="B") + + for i in T.serial(n): + for j in T.vectorized(m): + A[i, j] = A[i, j] + B[i, j] def test_domain_touched(): - i = te.var("i") - j = te.var("j") - n = tvm.runtime.convert(100) - m = te.var("m") - - a = tvm.tir.decl_buffer((n, m), name="a") - b = tvm.tir.decl_buffer((n, m), name="b") - - ir = tvm.tir.For( - i, - 0, - n, - tvm.tir.ForKind.SERIAL, - tvm.tir.For( - j, - 0, - m, - tvm.tir.ForKind.SERIAL, - tvm.tir.BufferStore( - a, - tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a, [i - 1, j - 1]), - [i, j], - ), - ), - ) + func = scalar_func + a, b = [func.buffer_map[var] for var in func.params] + ir = func.body a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False) @@ -78,5 +80,16 @@ def test_domain_touched(): assert len(b_domain_w) == 0 +def test_domain_touched_vector(): + func = tvm.lower(vector_func)["main"] + a, b = [func.buffer_map[var] for var in func.params] + + assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128 + assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128 + assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, True)[0].extent.value == 128 + assert tvm.arith._ffi_api.DomainTouched(func.body, b, True, False)[0].extent.value == 128 + assert tvm.arith._ffi_api.DomainTouched(func.body, b, True, False)[0].extent.value == 128 + + if __name__ == "__main__": test_domain_touched()