Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,13 @@ class CacheLocDetector : public StmtVisitor {
if (visited_block_ && visited_related_ && loc_pos_ == -1) {
// The offset of insert position from the block
loc_pos_ = i;
return;
break;
} else if (visited_related_) {
// If meet the target consumer, stop searching
visited_block_ = visited_block_ || previous_visited_block;
return;
break;
}
}
visited_block_ = visited_block_ || previous_visited_block;
}

void VisitStmt_(const BlockNode* block) final {
Expand Down Expand Up @@ -1078,7 +1078,7 @@ class ReIndexRewriter : public StmtExprMutator {
Region region_;
};

void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root) {
void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) {
class NotRegionCoverError : public ScheduleError {
public:
explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {}
Expand All @@ -1095,12 +1095,16 @@ The region cover property require to hold for every of its child blocks
IRModule mod_;
Block block_;
};
BlockScope scope = self->GetBlockScope(scope_root);
for (const auto& kv : scope->dst2deps) {
const StmtSRef& consumer_block_sref = kv.first;
if (!self->block_info.at(consumer_block_sref).region_cover) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
throw NotRegionCoverError(self->mod, GetRef<Block>(block));

for (const auto& child_block_sref : tir::GetChildBlocks(self, scope_root)) {
const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref);
for (const BufferRegion& region : child_block->reads) {
if (region->buffer.same_as(read_buffer)) {
if (!self->block_info.at(child_block_sref).region_cover) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
throw NotRegionCoverError(self->mod, GetRef<Block>(block));
}
}
}
}
}
Expand Down Expand Up @@ -1129,7 +1133,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
// Check required region cover for cache_read
CheckRegionCover(self, scope_sref);
CheckRegionCover(self, scope_sref, read_buffer);
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);

// Step 2. Create CacheStageInfo
Expand Down Expand Up @@ -1281,7 +1285,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);

// Check 3. Check required region cover for cache_read
CheckRegionCover(self, scope_sref);
CheckRegionCover(self, scope_sref, buffer);

// Check 4. Check if target block both read & write target buffer.
const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref);
Expand Down Expand Up @@ -1318,6 +1322,8 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int
StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
BlockInfo& block_info_read = self->block_info[result_block_sref];
block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info_read.region_cover = true;
block_info_read.scope->stage_pipeline = false;
results_block_sref.push_back(result_block_sref);

// Do cache write
Expand Down
76 changes: 76 additions & 0 deletions tests/python/unittest/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,33 @@ def elementwise_shape_int64(a: T.handle, c: T.handle) -> None:
C[vi, vj] = B[vi, vj] + 1.0


@T.prim_func
def func_nested_seq(b: T.handle, c: T.handle) -> None:
    # Input prim_func for the nested-SeqStmt cache_read test: block "A" writes an
    # intermediate buffer A that is read both inside a nested loop sequence (block
    # "B1") and by the final block "C". Caching A only for consumer "C" must not
    # disturb the B0/B1 sequence.
    A = T.alloc_buffer((128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))

    # Producer: fill the intermediate buffer A.
    for i, j in T.grid(128, 128):
        with T.block("A"):
            vi, vj = T.axis.remap("SS", [i, j])
            A[vi, vj] = 2.0
    # Nested sequence: two sibling inner loop nests (B0 then B1) under the same
    # outer i/j tile loops — this is the SeqStmt nesting the fix targets.
    for i, j in T.grid(8, 8):
        for x, y in T.grid(16, 16):
            with T.block("B0"):
                # 8x16 tiling reconstructs the full 128 iteration space.
                vi = T.axis.S(128, i * 16 + x)
                vj = T.axis.S(128, j * 16 + y)
                B[vi, vj] = 1.0
        for x, y in T.grid(16, 16):
            with T.block("B1"):
                vi = T.axis.S(128, i * 16 + x)
                vj = T.axis.S(128, j * 16 + y)
                # B1 also consumes A, but is NOT the consumer the test caches for.
                B[vi, vj] = A[vi, vj] + B[vi, vj]
    # Target consumer: block "C" reads A; cache_read is applied for this block only.
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0


@T.prim_func
def access_under_scope(b: T.handle, c: T.handle) -> None:
A = T.alloc_buffer((128, 128))
Expand Down Expand Up @@ -250,6 +277,47 @@ def inplace_call(data_io: T.Buffer[(64), "int32"]):
T.evaluate(T.call_extern("call_impl", data_io.data, dtype=""))


@T.prim_func
def cache_read_nested_seq_target(
    B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
    # Expected IR after sch.cache_read(block "C", buffer A, "global",
    # consumer_blocks=["C"]) on func_nested_seq: a new A_global cache buffer and
    # cache-fill block are inserted just before block "C", while the B0/B1 nested
    # sequence (which also reads A) is left reading the original buffer A.
    A = T.alloc_buffer([128, 128], dtype="float32")
    A_global = T.alloc_buffer([128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        with T.block("A"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads()
            T.writes(A[vi, vj])
            A[vi, vj] = T.float32(2)
    for i, j in T.grid(8, 8):
        for x, y in T.grid(16, 16):
            with T.block("B0"):
                vi = T.axis.spatial(128, i * 16 + x)
                vj = T.axis.spatial(128, j * 16 + y)
                T.reads()
                T.writes(B[vi, vj])
                B[vi, vj] = T.float32(1)
        for x, y in T.grid(16, 16):
            with T.block("B1"):
                vi = T.axis.spatial(128, i * 16 + x)
                vj = T.axis.spatial(128, j * 16 + y)
                # B1 still reads the ORIGINAL buffer A — it was excluded from
                # consumer_blocks, so it is not redirected to the cache.
                T.reads(A[vi, vj], B[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + B[vi, vj]
    # Inserted cache-read stage: copy A into A_global ahead of consumer "C".
    for ax0, ax1 in T.grid(128, 128):
        with T.block("A_global"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_global[v0, v1])
            A_global[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Only block "C" is redirected to read from the cache buffer.
            T.reads(A_global[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = A_global[vi, vj] * T.float32(2)


########## Expected function after cache_read ##########


Expand Down Expand Up @@ -989,6 +1057,14 @@ def test_cache_inplace():
verify_trace_roundtrip(sch=sch, mod=inplace_call, debug_mask=debug_mask)


def test_cache_read_nested_seq(use_block_name):
    """cache_read restricted via consumer_blocks works under a nested SeqStmt.

    Applies cache_read on buffer 0 of block "C" of ``func_nested_seq`` with the
    consumer set limited to "C" itself, then checks the result against the
    expected IR and that the recorded trace replays to the same module.
    """
    schedule = tir.Schedule(func_nested_seq, debug_mask="all")
    # Exercise both the block-name and BlockRV spellings of the API.
    target_block = "C" if use_block_name else schedule.get_block("C")
    schedule.cache_read(target_block, 0, "global", consumer_blocks=[target_block])
    tvm.ir.assert_structural_equal(cache_read_nested_seq_target, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=func_nested_seq)


########## Testcases for cache_write ##########


Expand Down