Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ struct BlockVarDomainInfo {
analyzer->CanProveEqual(bound.max(), intersect.max())) {
dom = bound;
bound = arith::IntSet::Nothing();
} else if (is_const_int(intersect.min()) && is_const_int(intersect.max())) {
// if the bound induce constant iter range, merge bound to loop domain
dom = intersect;
bound = arith::IntSet::Nothing();
}
}
};
Expand Down
9 changes: 0 additions & 9 deletions src/tir/schedule/primitive/decompose_padding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,15 +393,6 @@ class DecomposePaddingBlockReplacer : public StmtMutator {
return std::move(new_loop);
}

Stmt VisitStmt_(const SeqStmtNode* seq) final {
Array<Stmt> new_stmts;
new_stmts.reserve(seq->seq.size());
for (const Stmt& old_stmt : seq->seq) {
new_stmts.push_back(VisitStmt(old_stmt));
}
return SeqStmt::Flatten(new_stmts);
}

private:
const ReplaceDesc& desc_;
};
Expand Down
74 changes: 74 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,5 +1915,79 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle):
)


def test_compute_at_sliced_concatenate():
@T.prim_func
def before():
X = T.alloc_buffer((1, 16, 28, 64), "float32")
Y = T.alloc_buffer((1, 32, 28, 64), "float32")
Z = T.alloc_buffer((1, 53, 28, 64), "float32")
Concat = T.alloc_buffer((1, 101, 28, 64), "float32")
Slice = T.alloc_buffer((1, 87, 28, 64), "float32")
for ax0, ax1, ax2, ax3 in T.grid(1, 16, 28, 64):
with T.block("compute"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
X[v_ax0, v_ax1, v_ax2, v_ax3] = 1.0
for ax0, ax1, ax2, ax3 in T.grid(1, 101, 28, 64):
with T.block("T_concat"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
85 <= v_ax1,
X[v_ax0, v_ax1 - 85, v_ax2, v_ax3],
T.if_then_else(
53 <= v_ax1,
Y[v_ax0, v_ax1 - 53, v_ax2, v_ax3],
Z[v_ax0, v_ax1, v_ax2, v_ax3],
),
)
for ax0, ax1, ax2, ax3 in T.grid(1, 87, 28, 64):
with T.block("T_strided_slice"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3]

@T.prim_func
def expect():
X = T.alloc_buffer((1, 16, 28, 64))
Y = T.alloc_buffer((1, 32, 28, 64))
Z = T.alloc_buffer((1, 53, 28, 64))
Concat = T.alloc_buffer((1, 101, 28, 64))
Slice = T.alloc_buffer((1, 87, 28, 64))
for ax0 in range(1):
for ax0_1, ax1, ax2 in T.grid(2, 28, 64):
with T.block("compute"):
v_ax0 = T.axis.spatial(1, 0)
v_ax1 = T.axis.spatial(16, ax0_1)
v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2])
X[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(1)
for ax0_1, ax1, ax2 in T.grid(87, 28, 64):
with T.block("T_concat"):
v_ax0 = T.axis.spatial(1, 0)
v_ax1 = T.axis.spatial(101, ax0_1)
v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2])
Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
85 <= v_ax1,
X[v_ax0, v_ax1 - 85, v_ax2, v_ax3],
T.if_then_else(
53 <= v_ax1,
Y[v_ax0, v_ax1 - 53, v_ax2, v_ax3],
Z[v_ax0, v_ax1, v_ax2, v_ax3],
),
)
for ax1, ax2, ax3 in T.grid(87, 28, 64):
with T.block("T_strided_slice"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3]

sch = tir.Schedule(before, debug_mask="all")
blk1 = sch.get_block("compute")
blk2 = sch.get_block("T_concat")
blk3 = sch.get_block("T_strided_slice")
loop = sch.get_loops(blk3)[0]
sch.compute_at(blk2, loop)
sch.compute_at(blk1, loop)
after = sch.mod["main"]
assert_structural_equal_ignore_global_symbol(expect, after)
verify_trace_roundtrip(sch=sch, mod=before)


if __name__ == "__main__":
tvm.testing.main()