Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ class BufferTouchedDomain final : public StmtExprVisitor {
}

Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) {
Region ret;
auto kv = buffer_access_map_.find(buffer.get());
CHECK(kv != buffer_access_map_.end())
<< "The requested buffer is not contained in the provided stmt body.";
if (kv == buffer_access_map_.end()) {
LOG(WARNING) << "[arith::BufferDomainTouched] "
<< "The requested buffer is not contained in the provided stmt body: " << buffer;
return ret;
}

Region ret;
Range none;
BufferTouches bounds;
if (consider_loads && consider_stores) {
Expand Down Expand Up @@ -131,13 +134,16 @@ class BufferTouchedDomain final : public StmtExprVisitor {
}

private:
template <typename ArrayType>
void Touch(BufferTouches* bounds, const ArrayType& args) const {
void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) const {
if (args.size() > bounds->size()) {
bounds->resize(args.size());
}
for (size_t i = 0; i < args.size(); ++i) {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
if (args[i].as<RampNode>()) {
(*bounds)[i].emplace_back(IntSet::Vector(args[i]));
} else {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
}
}
}

Expand Down
63 changes: 38 additions & 25 deletions tests/python/unittest/test_arith_domain_touched.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,36 @@
# under the License.
import tvm
from tvm import te
from tvm.script import tir as T


@T.prim_func
def scalar_func(a: T.handle, b: T.handle):
m = T.var("int32")
n = T.int32(100)
A = T.match_buffer(a, (n, m), name="A")
B = T.match_buffer(b, (n, m), name="B")

for i, j in T.grid(n, m):
A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]


@T.prim_func
def vector_func(a: T.handle, b: T.handle):
n = T.var("int32")
m = T.int32(128)
A = T.match_buffer(a, (n, m), name="A")
B = T.match_buffer(b, (n, m), name="B")

for i in T.serial(n):
for j in T.vectorized(m):
A[i, j] = A[i, j] + B[i, j]


def test_domain_touched():
i = te.var("i")
j = te.var("j")
n = tvm.runtime.convert(100)
m = te.var("m")

a = tvm.tir.decl_buffer((n, m), name="a")
b = tvm.tir.decl_buffer((n, m), name="b")

ir = tvm.tir.For(
i,
0,
n,
tvm.tir.ForKind.SERIAL,
tvm.tir.For(
j,
0,
m,
tvm.tir.ForKind.SERIAL,
tvm.tir.BufferStore(
a,
tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a, [i - 1, j - 1]),
[i, j],
),
),
)
func = scalar_func
a, b = [func.buffer_map[var] for var in func.params]
ir = func.body

a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)

Expand Down Expand Up @@ -78,5 +80,16 @@ def test_domain_touched():
assert len(b_domain_w) == 0


def test_domain_touched_vector():
func = tvm.lower(vector_func)["main"]
a, b = [func.buffer_map[var] for var in func.params]

assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, True)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, b, True, False)[0].extent.value == 128
assert tvm.arith._ffi_api.DomainTouched(func.body, b, True, False)[0].extent.value == 128


if __name__ == "__main__":
test_domain_touched()