Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2418,6 +2418,190 @@ def cbr_2(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
)


def test_cpu_tbg():
# fmt: off
@T.prim_func
def tbg_0(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, 12, 64), "float32"], C: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
query_T = T.alloc_buffer([1, 12, 128, 64], dtype="float32")
value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32")
C_global = T.alloc_buffer([1, 12, 128, 128], dtype="float32")
for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1 in T.grid(1, 1, 1, 2, 1, 6, 2):
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 64):
with T.block("value_T"):
b = T.axis.spatial(1, ax0)
h = T.axis.spatial(12, i1_1 * 2 + ax1)
d = T.axis.spatial(64, ax2)
l = T.axis.spatial(128, i3_0 * 64 + ax3)
T.reads(value[b, l, h, d])
T.writes(value_T[b, h, d, l])
value_T[b, h, d, l] = value[b, l, h, d]
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 64):
with T.block("query_T"):
b = T.axis.spatial(1, ax0)
h = T.axis.spatial(12, i1_1 * 2 + ax1)
l = T.axis.spatial(128, i2_1 * 64 + ax2)
d = T.axis.spatial(64, ax3)
T.reads(query[b, l, h, d])
T.writes(query_T[b, h, l, d])
query_T[b, h, l, d] = query[b, l, h, d]
for i3_1 in T.serial(8):
for i4_0, i0_2, i1_2, i2_2, i3_2, i4_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2):
with T.block("C"):
b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3)
i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3)
j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3)
k = T.axis.reduce(64, i4_0 * 64 + i4_1)
T.reads(query_T[b, h, i, k], value_T[b, h, k, j])
T.writes(C_global[b, h, i, j])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
C_global[b, h, i, j] = T.float32(0)
C_global[b, h, i, j] = C_global[b, h, i, j] + query_T[b, h, i, k] * value_T[b, h, k, j]
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 8):
with T.block("C_global"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(12, i1_1 * 2 + ax1)
v2 = T.axis.spatial(128, i2_1 * 64 + ax2)
v3 = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + ax3)
T.reads(C_global[v0, v1, v2, v3])
T.writes(C[v0, v1, v2, v3])
C[v0, v1, v2, v3] = C_global[v0, v1, v2, v3]
@T.prim_func
def tbg_1(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, 12, 64), "float32"], C: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
query_T = T.alloc_buffer([1, 12, 128, 64], dtype="float32")
value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32")
C_global = T.alloc_buffer([1, 12, 128, 128], dtype="float32")
for i0, i1, i2, i3 in T.grid(1, 12, 128, 64):
with T.block("query_T"):
b, h, l, d = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(query[b, l, h, d])
T.writes(query_T[b, h, l, d])
query_T[b, h, l, d] = query[b, l, h, d]
for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 1, 2):
for i0_1, i1_1, i2_1, i3_1, i4_0, i0_2, i1_2, i2_2, i3_2, i4_1 in T.grid(1, 6, 2, 8, 1, 1, 2, 2, 4, 64):
for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 2):
with T.block("value_T"):
b = T.axis.spatial(1, ax0)
h = T.axis.spatial(12, i1_1 * 2 + i1_2 + ax1)
d = T.axis.spatial(64, i4_1 + ax2)
l = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + ax3)
T.reads(value[b, l, h, d])
T.writes(value_T[b, h, d, l])
value_T[b, h, d, l] = value[b, l, h, d]
for i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 32, 2):
with T.block("C"):
b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3)
i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3)
j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3)
k = T.axis.reduce(64, i4_0 * 64 + i4_1)
T.reads(query_T[b, h, i, k], value_T[b, h, k, j])
T.writes(C_global[b, h, i, j])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
C_global[b, h, i, j] = T.float32(0)
C_global[b, h, i, j] = C_global[b, h, i, j] + query_T[b, h, i, k] * value_T[b, h, k, j]
for ax0, ax1, ax2, ax3 in T.grid(1, 12, 128, 64):
with T.block("C_global"):
v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
v3 = T.axis.spatial(128, i3_0 * 64 + ax3)
T.reads(C_global[v0, v1, v2, v3])
T.writes(C[v0, v1, v2, v3])
C[v0, v1, v2, v3] = C_global[v0, v1, v2, v3]
@T.prim_func
def tbg_2(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, 12, 64), "float32"], C: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32")
for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 2, 1, 6, 2, 8):
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 8):
with T.block("value_T"):
b = T.axis.spatial(1, ax0)
h = T.axis.spatial(12, i1_1 * 2 + ax1)
d = T.axis.spatial(64, ax2)
l = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + ax3)
T.reads(value[b, l, h, d])
T.writes(value_T[b, h, d, l])
value_T[b, h, d, l] = value[b, l, h, d]
for i4_0, i0_2, i1_2, i2_2, i3_2, i4_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2):
with T.block("C"):
b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3)
i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3)
j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3)
k = T.axis.reduce(64, i4_0 * 64 + i4_1)
T.reads(query[b, i, h, k], value_T[b, h, k, j])
T.writes(C[b, h, i, j])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
C[b, h, i, j] = T.float32(0)
C[b, h, i, j] = C[b, h, i, j] + query[b, i, h, k] * value_T[b, h, k, j]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 6, 2, 1]),
("SamplePerfectTile", [1, 2, 2, 32]),
("SamplePerfectTile", [2, 8, 4, 2]),
("SamplePerfectTile", [1, 64]),
("SampleCategorical", 2),
("SampleComputeLocation", 6),
("SampleComputeLocation", 6),
]
decision_1 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 6, 2, 1]),
("SamplePerfectTile", [1, 2, 2, 32]),
("SamplePerfectTile", [2, 8, 4, 2]),
("SamplePerfectTile", [1, 64]),
("SampleCategorical", 2),
("SampleComputeLocation", 13),
("SampleComputeLocation", -1),
]
decision_2 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 6, 2, 1]),
("SamplePerfectTile", [1, 2, 2, 32]),
("SamplePerfectTile", [2, 8, 4, 2]),
("SamplePerfectTile", [1, 64]),
("SampleCategorical", 3),
("SampleComputeLocation", 7),
("SampleComputeLocation", -2),
]
mod = create_te_workload("TBG", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[tbg_0, tbg_1, tbg_2],
expected_decisions=[decision_0, decision_1, decision_2],
)


if __name__ == "__main__":
test_cpu_c1d()
test_cpu_c2d()
Expand All @@ -2431,3 +2615,4 @@ def cbr_2(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
test_cpu_nrm()
test_cpu_sfm()
test_cpu_cbr()
test_cpu_tbg()
85 changes: 85 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,90 @@ def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
)


def test_cuda_tbg():
# fmt: off
@T.prim_func
def tbg_0(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, 12, 64), "float32"], C: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.unroll_explicit":1024})
C_local = T.alloc_buffer([1, 12, 128, 128], dtype="float32", scope="local")
query_T_shared = T.alloc_buffer([1, 12, 128, 64], dtype="float32", scope="shared")
value_T_shared = T.alloc_buffer([1, 12, 64, 128], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(4, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(192, thread="vthread.x"):
for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(32, thread="threadIdx.x"):
for i4_0 in T.serial(8):
for ax0_ax1_ax2_ax3_fused in T.serial(12288):
with T.block("query_T_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(12, ax0_ax1_ax2_ax3_fused // 1024)
v2 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused % 1024 // 8)
v3 = T.axis.spatial(64, i4_0 * 8 + ax0_ax1_ax2_ax3_fused % 8)
T.reads(query[v0, v2, v1, v3])
T.writes(query_T_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":3})
query_T_shared[v0, v1, v2, v3] = query[v0, v2, v1, v3]
for ax0_ax1_ax2_ax3_fused in T.serial(3072):
with T.block("value_T_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(12, ax0_ax1_ax2_ax3_fused // 256)
v2 = T.axis.spatial(64, i4_0 * 8 + ax0_ax1_ax2_ax3_fused % 256 // 32)
v3 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused * 32 + ax0_ax1_ax2_ax3_fused % 32)
T.reads(value[v0, v3, v1, v2])
T.writes(value_T_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":4})
value_T_shared[v0, v1, v2, v3] = value[v0, v3, v1, v2]
for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(4, 1, 2, 1, 1, 2, 1, 1, 4, 1):
with T.block("C"):
b = T.axis.spatial(1, i0_4 + i0_3)
h = T.axis.spatial(12, i1_4 + i0_1_i1_1_i2_1_i3_1_fused // 32 * 2 + i1_3)
i = T.axis.spatial(128, i0_1_i1_1_i2_1_i3_1_fused % 32 // 8 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 4 * 4 + i2_3 * 4 + i2_4)
j = T.axis.spatial(128, i3_4 + i0_0_i1_0_i2_0_i3_0_fused * 32 + i0_1_i1_1_i2_1_i3_1_fused % 8 * 4 + i0_2_i1_2_i2_2_i3_2_fused % 4 + i3_3)
k = T.axis.reduce(64, i4_0 * 8 + i4_1 * 2 + i4_2)
T.reads(query_T_shared[b, h, i, k], value_T_shared[b, h, k, j])
T.writes(C_local[b, h, i, j])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
with T.init():
C_local[b, h, i, j] = T.float32(0)
C_local[b, h, i, j] = C_local[b, h, i, j] + query_T_shared[b, h, i, k] * value_T_shared[b, h, k, j]
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 4, 1):
with T.block("C_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(12, i0_1_i1_1_i2_1_i3_1_fused // 32 * 2 + ax1)
v2 = T.axis.spatial(128, i0_1_i1_1_i2_1_i3_1_fused % 32 // 8 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 4 * 4 + ax2)
v3 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused * 32 + i0_1_i1_1_i2_1_i3_1_fused % 8 * 4 + i0_2_i1_2_i2_2_i3_2_fused % 4 + ax3)
T.reads(C_local[v0, v1, v2, v3])
T.writes(C[v0, v1, v2, v3])
C[v0, v1, v2, v3] = C_local[v0, v1, v2, v3]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1, 1]),
("SamplePerfectTile", [1, 6, 1, 2, 1]),
("SamplePerfectTile", [1, 4, 8, 1, 4]),
("SamplePerfectTile", [4, 8, 4, 1, 1]),
("SamplePerfectTile", [8, 4, 2]),
("SampleCategorical", 2),
("SampleCategorical", 3),
("SampleCategorical", 4),
]
mod = create_te_workload("TBG", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[tbg_0],
expected_decisions=[decision_0],
)


if __name__ == "__main__":
test_cuda_c1d()
test_cuda_c2d()
Expand All @@ -1230,3 +1314,4 @@ def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
test_cuda_nrm()
test_cuda_sfm()
test_cuda_cbr()
test_cuda_tbg()