Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 25 additions & 42 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ class PrimFuncSpecializer : public StmtExprMutator {
buffer_map.Set(var, new_buffer);
if (!new_buffer.same_as(buffer)) {
buffer_map_updated = true;
specializer.buffer_map_[buffer] = new_buffer;
}
}

Expand Down Expand Up @@ -116,7 +115,7 @@ class PrimFuncSpecializer : public StmtExprMutator {
Stmt VisitStmt_(const BlockNode* op) final {
// Step.0. Define buffer mappings which is allocated inside the block
Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
std::bind(&PrimFuncSpecializer::MutateBuffer, this, std::placeholders::_1));

// Step.1. Recursively visit block body
Stmt stmt = StmtExprMutator::VisitStmt_(op);
Expand All @@ -141,31 +140,21 @@ class PrimFuncSpecializer : public StmtExprMutator {
}

// Visits a BufferStore with the default mutator and then rewrites the buffer it refers to.
// NOTE(review): this text is a diff hunk rendered without +/- markers, so two alternative
// bodies appear back to back inside one pair of braces. Judging by the hunk header
// (-141,31 +140,21), the first body looks like the removed version and the second like the
// added one -- confirm against the upstream commit.
Stmt VisitStmt_(const BufferStoreNode* op) final {
// Body 1: after visiting children, swap the buffer only if buffer_map_ already holds an
// entry for it; otherwise hand back the visited node unchanged.
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<BufferStoreNode>();
ICHECK(op != nullptr);
auto it = buffer_map_.find(op->buffer);
if (it == buffer_map_.end()) {
return GetRef<BufferStore>(op);
} else {
auto n = CopyOnWrite(op);
n->buffer = it->second;
return Stmt(n);
// Body 2: after visiting children, pass the buffer through MutateBuffer and write the
// result back (via CopyOnWrite()) only when a different Buffer object was returned.
auto stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto new_buffer = MutateBuffer(stmt->buffer);
if (!new_buffer.same_as(stmt->buffer)) {
stmt.CopyOnWrite()->buffer = new_buffer;
}
return std::move(stmt);
}

// Visits a BufferLoad with the default mutator and then rewrites the buffer it refers to.
// NOTE(review): this text is a diff hunk rendered without +/- markers, so two alternative
// bodies appear back to back inside one pair of braces. Judging by the hunk header
// (-141,31 +140,21), the first body looks like the removed version and the second like the
// added one -- confirm against the upstream commit.
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
// Body 1: after visiting children, look the buffer up in buffer_map_; if an entry exists,
// copy the node and point the copy at the mapped buffer, otherwise return the node as is.
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<BufferLoadNode>();
ICHECK(op != nullptr);
auto it = buffer_map_.find(op->buffer);
if (it == buffer_map_.end()) {
return GetRef<BufferLoad>(op);
} else {
auto n = make_object<BufferLoadNode>(*op);
n->buffer = it->second;
return PrimExpr(n);
// Body 2: after visiting children, pass the buffer through MutateBuffer and write the
// result back (via CopyOnWrite()) only when a different Buffer object was returned.
auto expr = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto new_buffer = MutateBuffer(expr->buffer);
if (!new_buffer.same_as(expr->buffer)) {
expr.CopyOnWrite()->buffer = new_buffer;
}
return std::move(expr);
}

PrimExpr VisitExpr_(const VarNode* op) final {
Expand Down Expand Up @@ -198,22 +187,28 @@ class PrimFuncSpecializer : public StmtExprMutator {

private:
// Returns a Buffer whose shape, strides and elem_offset have each been passed through
// VisitExpr. When none of the three changed, the input buffer itself is returned.
// NOTE(review): this text is a diff hunk rendered without +/- markers, so lines from two
// versions of the function are interleaved below. Judging by the hunk header
// (-198,22 +187,28), the buffer_map_ lookup, the `new_buffer` local and the buffer_map_
// store look like additions, while the `return buffer; } else {` / `return Buffer(n);`
// path looks like the removed version -- confirm against the upstream commit.
Buffer MutateBuffer(const Buffer& buffer) {
// Reuse the result recorded in buffer_map_ if this buffer has been seen before.
if (auto it = buffer_map_.find(buffer); it != buffer_map_.end()) {
return it->second;
}

// Run every shape and stride expression, and the element offset, through the mutator.
Array<PrimExpr> shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); });
Array<PrimExpr> strides =
buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); });

PrimExpr elem_offset = VisitExpr(buffer->elem_offset);

// Variant A: early-return the input when nothing changed, else build and return a copy.
if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
buffer->strides.same_as(strides)) {
return buffer;
} else {
// Variant B: default to the input buffer and replace it only when some field changed.
Buffer new_buffer = buffer;
if (!buffer->elem_offset.same_as(elem_offset) || !buffer->shape.same_as(shape) ||
!buffer->strides.same_as(strides)) {
// Copy the BufferNode and overwrite just the three mutated fields.
auto n = make_object<BufferNode>(*buffer.get());
n->elem_offset = std::move(elem_offset);
n->shape = std::move(shape);
n->strides = std::move(strides);
return Buffer(n);
new_buffer = Buffer(n);
}

// Record the outcome (changed or not) so later lookups for this buffer hit the map.
buffer_map_[buffer] = new_buffer;
return new_buffer;
}

Range MutateRange(const Range& range) {
Expand All @@ -226,26 +221,14 @@ class PrimFuncSpecializer : public StmtExprMutator {
}
}

// Mutates a buffer via MutateBuffer and, when the result is a different object, records the
// old -> new pairing in buffer_map_. Returns the original buffer when nothing changed.
// NOTE(review): the hunk header (-226,26 +221,14) suggests this whole function is on the
// removed side of the diff -- confirm against the upstream commit.
Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
Buffer buf = MutateBuffer(alloc_buf);
if (buf.same_as(alloc_buf)) {
return alloc_buf;
} else {
// The buffer must not already have a mapping before one is registered here.
ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
buffer_map_[alloc_buf] = buf;
return buf;
}
}

// Rewrites a BufferRegion: resolves its buffer, maps every Range through MutateRange, and
// returns the input region object itself when neither part changed.
// NOTE(review): this text is a diff hunk rendered without +/- markers, so several statements
// appear twice in two forms. Judging by the hunk header (-226,26 +221,14), the buffer_map_
// lookup forms look like the removed lines and the MutateBuffer forms like the added ones --
// confirm against the upstream commit.
BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
// Form 1: take the replacement from buffer_map_ if present, else keep the region's buffer.
auto it = buffer_map_.find(buffer_region->buffer);
const Buffer& buffer = it != buffer_map_.end() ? it->second : buffer_region->buffer;
// Form 2: obtain the buffer by calling MutateBuffer on the region's buffer.
Buffer buffer = MutateBuffer(buffer_region->buffer);
Array<Range> region = buffer_region->region.Map(
std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
// Unchanged check, form 1 (no map entry) followed by form 2 (same buffer object).
if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
return buffer_region;
} else {
// Construction, form 1 (copies buffer) followed by form 2 (moves buffer).
return BufferRegion(buffer, std::move(region));
return BufferRegion(std::move(buffer), std::move(region));
}
}

Expand Down
32 changes: 26 additions & 6 deletions tests/python/unittest/test_tir_specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,30 @@ def test_specialize_with_const_folding():
tvm.ir.assert_structural_equal(func, param_in_arith_exprs_n_16)


# NOTE(review): leading indentation was lost when this page was captured; the nesting below
# (two @T.prim_func definitions inside the test function) has to be restored before running.
def test_specialize_with_pointer_size():
"""Rewrite implicit buffers

The first appearance of a `tir::Buffer` may be at its usage in
BufferStore or BufferLoad. The specialized parameter should be
updated in these buffers as well as explicitly defined buffers.
"""

# Input function: buffer A is declared in the body over the raw pointer a_ptr, with the
# scalar parameter n as its extent and as the loop bound.
@T.prim_func
def before(a_ptr: T.handle("float32"), n: T.int64):
A = T.Buffer(n, "float32", data=a_ptr)
for i in range(n):
A[i] = 0.0

# Expected result: n is no longer a parameter, and both the buffer extent and the loop
# bound are the int64 constant 16.
@T.prim_func
def expected(a_ptr: T.handle("float32")):
A = T.Buffer(T.int64(16), "float32", data=a_ptr)
for i in range(T.int64(16)):
A[i] = 0.0

# Take the second parameter (n) and specialize it to the constant 16 of n's own dtype.
_, n = before.params
after = before.specialize({n: tvm.tir.const(16, n.dtype)})
# The specialized function must be structurally equal to `expected`.
tvm.ir.assert_structural_equal(after, expected)


# Script entry point.
# NOTE(review): this is a diff hunk rendered without +/- markers. Judging by the hunk header
# (-242,10 +242,30), the six explicit test_* calls look like the removed lines and
# tvm.testing.main() like the added replacement -- confirm against the upstream commit.
if __name__ == "__main__":
test_specialize_nothing()
test_specialize_matmul()
test_specialize_elemwise()
test_specialize_mem_copy()
test_specialize_recursive_load()
test_specialize_with_const_folding()
tvm.testing.main()