Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 76 additions & 36 deletions python/tvm/tir/schedule/schedule.py

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1241,11 +1241,10 @@ struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> {
Integer buffer_index_type) {
PythonAPICall py("reindex");
py.Input("block", block);
py.Input("buffer_index", buffer_index);
py.Input("buffer_index_type", '"' +
std::string(BufferIndexType2Str(
static_cast<BufferIndexType>(buffer_index_type->value))) +
'"');
std::ostringstream os;
os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
<< "\", " << buffer_index << ")";
py.Input("buffer", os.str());
py.SingleOutput(outputs);
return py.Str();
}
Expand Down
94 changes: 48 additions & 46 deletions tests/python/unittest/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,13 +741,15 @@ def block_predicate_cache_write_output_buf() -> None:

########## Testcases for cache_read ##########

use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})

def test_cache_read_elementwise():

def test_cache_read_elementwise(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_c = sch.get_block("C")
cached_a = sch.cache_read(block_b, 0, "global")
cached_b = sch.cache_read(block_c, 0, "local")
cached_a = sch.cache_read("B" if use_block_name else block_b, 0, "global")
cached_b = sch.cache_read("C" if use_block_name else block_c, 0, "local")
assert sch.get(cached_a) == sch.get(sch.get_block("A_global"))
assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
assert sch.get(block_b) == sch.get(sch.get_block("B"))
Expand All @@ -756,87 +758,87 @@ def test_cache_read_elementwise():
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_cache_read_under_scope():
def test_cache_read_under_scope(use_block_name):
sch = tir.Schedule(access_under_scope, debug_mask="all")
block_b = sch.get_block("B")
block_c = sch.get_block("C")
block_b = "B" if use_block_name else sch.get_block("B")
block_c = "C" if use_block_name else sch.get_block("C")
sch.cache_read(block_b, 0, "local")
sch.cache_read(block_c, 0, "global")
tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=access_under_scope)


def test_cache_read_opaque_access():
def test_cache_read_opaque_access(use_block_name):
sch = tir.Schedule(opaque_access, debug_mask="all")
block = sch.get_block("load_store")
block = "load_store" if use_block_name else sch.get_block("load_store")
sch.cache_read(block, 0, "global")
tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=opaque_access)


def test_cache_read_location():
def test_cache_read_location(use_block_name):
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
sch.cache_read(block_b, 0, "global")
tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)


def test_continuous_cache_read():
def test_continuous_cache_read(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_c = sch.get_block("C")
block_c = "C" if use_block_name else sch.get_block("C")
sch.cache_read(block_c, 0, "shared")
sch.cache_read(block_c, 0, "local")
tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_cache_read_with_block_predicate():
def test_cache_read_with_block_predicate(use_block_name):
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
block = sch.get_block("consumer")
block = "consumer" if use_block_name else sch.get_block("consumer")
sch.cache_read(block, 0, "shared")
tvm.ir.assert_structural_equal(block_predicate_cache_read, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)


def test_cache_read_non_int32_shape():
def test_cache_read_non_int32_shape(use_block_name):
sch = tir.Schedule(elementwise_shape_int64, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
sch.cache_read(block_b, 0, "global")
tvm.ir.assert_structural_equal(cache_read_shape_int64, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)


def test_cache_read_fail_multi_producer():
def test_cache_read_fail_multi_producer(use_block_name):
sch = tir.Schedule(func_multi_producer, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_read(block_b, 0, "global")


def test_cache_read_fail_index_out_of_bound():
def test_cache_read_fail_index_out_of_bound(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_read(block_b, 1, "global")


def test_cache_read_fail_invalid_storage_scope():
def test_cache_read_fail_invalid_storage_scope(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_read(block_b, 0, "test_scope")


########## Testcases for cache_write ##########


def test_cache_write_elementwise():
def test_cache_write_elementwise(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_c = sch.get_block("C")
cached_b = sch.cache_write(block_b, 0, "local")
cached_c = sch.cache_write(block_c, 0, "global")
cached_b = sch.cache_write("B" if use_block_name else block_b, 0, "local")
cached_c = sch.cache_write("C" if use_block_name else block_c, 0, "global")
assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
assert sch.get(cached_c) == sch.get(sch.get_block("C_global"))
assert sch.get(block_b) == sch.get(sch.get_block("B"))
Expand All @@ -845,10 +847,10 @@ def test_cache_write_elementwise():
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_cache_write_under_scope():
def test_cache_write_under_scope(use_block_name):
sch = tir.Schedule(access_under_scope, debug_mask="all")
block_a = sch.get_block("A")
block_b = sch.get_block("B")
block_a = "A" if use_block_name else sch.get_block("A")
block_b = "B" if use_block_name else sch.get_block("B")
block_scope = sch.get_block("scope")
sch.cache_write(block_a, 0, "local")
sch.cache_write(block_b, 0, "global")
Expand All @@ -857,70 +859,70 @@ def test_cache_write_under_scope():
verify_trace_roundtrip(sch=sch, mod=access_under_scope)


def test_cache_write_opaque_access():
def test_cache_write_opaque_access(use_block_name):
sch = tir.Schedule(opaque_access, debug_mask="all")
block_store = sch.get_block("load_store")
block_opaque = sch.get_block("opaque")
block_match_buffer = sch.get_block("match_buffer")
block_store = "load_store" if use_block_name else sch.get_block("load_store")
block_opaque = "opaque" if use_block_name else sch.get_block("opaque")
block_match_buffer = "match_buffer" if use_block_name else sch.get_block("match_buffer")
sch.cache_write(block_store, 0, "global")
sch.cache_write(block_opaque, 0, "global")
sch.cache_write(block_match_buffer, 0, "global")
tvm.ir.assert_structural_equal(cache_write_opaque_access, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=opaque_access)


def test_cache_write_location():
def test_cache_write_location(use_block_name):
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
block_a = sch.get_block("A")
block_a = "A" if use_block_name else sch.get_block("A")
sch.cache_write(block_a, 0, "global")
tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)


def test_continuous_cache_write():
def test_continuous_cache_write(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
sch.cache_write(block_b, 0, "shared")
sch.cache_write(block_b, 0, "local")
tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_cache_write_with_block_predicate():
def test_cache_write_with_block_predicate(use_block_name):
# cache write for intermediate buffer
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
block = sch.get_block("producer")
block = "producer" if use_block_name else sch.get_block("producer")
sch.cache_write(block, 0, "shared")
tvm.ir.assert_structural_equal(block_predicate_cache_write_intermediate_buf, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
# cache write for external buffer
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
block = sch.get_block("consumer")
block = "consumer" if use_block_name else sch.get_block("consumer")
sch.cache_write(block, 0, "shared")
tvm.ir.assert_structural_equal(block_predicate_cache_write_output_buf, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)


def test_cache_write_fail_multi_producer():
def test_cache_write_fail_multi_producer(use_block_name):
sch = tir.Schedule(func_multi_producer, debug_mask="all")
block_a0 = sch.get_block("A0")
block_a1 = sch.get_block("A1")
block_a0 = "A0" if use_block_name else sch.get_block("A0")
block_a1 = "A1" if use_block_name else sch.get_block("A1")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_write(block_a0, 0, "global")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_write(block_a1, 0, "global")


def test_cache_write_fail_index_out_of_bound():
def test_cache_write_fail_index_out_of_bound(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_write(block_b, 1, "global")


def test_cache_write_fail_invalid_storage_scope():
def test_cache_write_fail_invalid_storage_scope(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_write(block_b, 0, "test_scope")

Expand Down
Loading