Description
When tuning the fp16 Tensor Core dense_add kernel, tuning fails on some shapes, and the reported error is non-deterministic. For example, tuning fails when the workload is N=1, M=1000, K=512.
Two kinds of errors are reported; from my observation, the following one appears more frequently:
2023-02-27 14:11:46 [INFO] Logging directory: /tmp/tmp71o3_ldv/logs
2023-02-27 14:11:46 [INFO] LocalBuilder: max_workers = 11
2023-02-27 14:11:47 [INFO] LocalRunner: max_workers = 1
2023-02-27 14:11:48 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"
2023-02-27 14:11:48 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
Traceback (most recent call last):
File "bug_tune_dense_add.py", line 507, in <module>
test_tune_tir_matmul_cuda_tensor_core()
File "bug_tune_dense_add.py", line 195, in test_tune_tir_matmul_cuda_tensor_core
database = tune_tir(
File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py", line 104, in tune_tir
return tune_tasks(
File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py", line 117, in tune_tasks
task_scheduler.tune(
File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune
_ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member
File "/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
6: TVMFuncCall
5: _ZN3tvm7runtime13PackedF
4: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
3: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
2: tvm::meta_schedule::ReplayTraceNode::GenerateMeasureCandidates()
1: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates()
0: tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)
File "/mnt/disk5/wll/code/metaschedule/src/support/parallel_for.cc", line 128
RuntimeError: parallel_for_dynamic error with ScheduleError: (not rendered)
The other error is reported less frequently:
2023-02-27 14:20:13 [INFO] Logging directory: /tmp/tmputfxvrl5/logs
2023-02-27 14:20:13 [INFO] LocalBuilder: max_workers = 11
2023-02-27 14:20:14 [INFO] LocalRunner: max_workers = 1
2023-02-27 14:20:15 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"
Traceback (most recent call last):
File "bug_tune_dense_add.py", line 507, in <module>
test_tune_tir_matmul_cuda_tensor_core()
File "bug_tune_dense_add.py", line 195, in test_tune_tir_matmul_cuda_tensor_core
database = tune_tir(
File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py", line 104, in tune_tir
return tune_tasks(
File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py", line 117, in tune_tasks
task_scheduler.tune(
File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune
_ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member
File "/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
9: TVMFuncCall
8: _ZN3tvm7runtime13PackedF
7: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
6: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
5: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
4: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
3: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<tvm::meta_schedule::State, std::allocator<tvm::meta_schedule::State> >)
2: tvm::meta_schedule::MultiLevelTilingNode::AddReadReuse(tvm::meta_schedule::State) const
1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
ScheduleError: An error occurred in the schedule primitive 'compute-at'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(p0_handle: T.handle, p1_handle: T.handle, p2_handle: T.handle, T_add_handle: T.handle):
T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True})
p0 = T.match_buffer(p0_handle, (T.int64(8), T.int64(512)), "float16")
p1 = T.match_buffer(p1_handle, (T.int64(1000), T.int64(512)), "float16")
p2 = T.match_buffer(p2_handle, (T.int64(8), T.int64(1000)), "float16")
T_add = T.match_buffer(T_add_handle, (T.int64(8), T.int64(1000)), "float16")
# tir.Block#0
with T.block("root"):
^^^^^^^^^^^^^^^^^^^^^
T.reads()
^^^^^^^^^
T.writes()
^^^^^^^^^^
T_matmul_NT = T.alloc_buffer((T.int64(8), T.int64(1000)), "float16")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p0_reindex = T.alloc_buffer((T.int64(16), T.int64(512)), "float16")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p1_reindex = T.alloc_buffer((T.int64(1008), T.int64(512)), "float16")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn = T.alloc_buffer((T.int64(16), T.int64(1008)), "float16", scope="shared.dyn")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((T.int64(16), T.int64(1008)), "float16", scope="wmma.accumulator")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p0_reindex_shared_dyn = T.alloc_buffer((T.int64(16), T.int64(512)), "float16", scope="shared.dyn")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(512)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("p0_reindex_reindex"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(16), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(512), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(p0[v0, v1])
^^^^^^^^^^^^^^^^^^^
T.writes(p0_reindex[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p0_reindex[v0, v1] = T.if_then_else(v0 < T.int64(8), p0[v0, v1], T.float16(0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(1008)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(512)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("p1_reindex_reindex"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(1008), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(512), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(p1[v0, v1])
^^^^^^^^^^^^^^^^^^^
T.writes(p1_reindex[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p1_reindex[v0, v1] = T.if_then_else(v0 < T.int64(1000), p1[v0, v1], T.float16(0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(512)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("p0_reindex_shared.dyn"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(16), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(512), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(p0_reindex[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(p0_reindex_shared_dyn[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
p0_reindex_shared_dyn[v0, v1] = p0_reindex[v0, v1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_0_ax1_0_0_fused in T.thread_binding(T.int64(1), thread="blockIdx.y"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_1_ax1_0_1_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_2_ax1_0_2_fused in T.thread_binding(T.int64(3), thread="threadIdx.y"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2_0_0 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2_0_1 in range(T.int64(32)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_3 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_0_3 in range(T.int64(21)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2_0_2 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0_4 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_0_4 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_matmul_NT_o"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_o = T.axis.spatial(T.int64(1), ax0_0_4 + ax0_0_3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_o = T.axis.spatial(T.int64(63), ax1_0_4 + ax0_0_2_ax1_0_2_fused * T.int64(21) + ax1_0_3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v2_o = T.axis.reduce(T.int64(32), ax2_0_0 * T.int64(32) + ax2_0_1 + ax2_0_2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(p0_reindex_shared_dyn[T.int64(0):T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], p1_reindex[v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f16_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f16", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "warp_execution": 1})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.init():
^^^^^^^^^^^^^^
for ax0_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_matmul_NT_init"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_i_init = T.axis.spatial(T.int64(16), ax0_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_i_init = T.axis.spatial(T.int64(16), ax1_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads()
^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * T.int64(16) + v1_i_init])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * T.int64(16) + v1_i_init] = T.float16(0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_matmul_NT"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_i = T.axis.spatial(T.int64(16), ax0_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_i = T.axis.spatial(T.int64(16), ax1_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v2_i = T.axis.reduce(T.int64(16), ax2_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i], p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i], p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i] = T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i] + p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i] * p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_0 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_0 in range(T.int64(21)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator_o"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_o = T.axis.spatial(T.int64(1), ax0_0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_o = T.axis.spatial(T.int64(63), ax0_0_2_ax1_0_2_fused * T.int64(21) + ax1_0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f16_shared_dyn"})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1_1 in range(T.int64(16)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0_i = T.axis.spatial(T.int64(16), ax0_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1_i = T.axis.spatial(T.int64(16), ax1_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i] = T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_ax1_fused in range(T.int64(16128)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_matmul_NT_reindex_shared.dyn"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(16), ax0_ax1_fused // T.int64(1008))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(1008), ax0_ax1_fused % T.int64(1008))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.where(ax0_ax1_fused < T.int64(8056) and ax0_ax1_fused % T.int64(1008) < T.int64(1000))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT_reindex_shared_dyn[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul_NT[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"meta_schedule.cooperative_fetch": 1})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul_NT[v0, v1] = T_matmul_NT_reindex_shared_dyn[v0, v1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(8)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(1000)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("T_add"):
^^^^^^^^^^^^^^^^^^^^^^
v_ax0 = T.axis.spatial(T.int64(8), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v_ax1 = T.axis.spatial(T.int64(1000), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_add[v_ax0, v_ax1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The scope tir.Block#0 is not a stage pipeline.
Definition of a scope that is a stage pipeline:
- The region cover property holds for every of its child blocks
- No write-after-read dependency or opaque dependency,
- only read-after-write and write-after-write are allowed
- All the statements in the scope are schedulable statements, i.e. Block and For
I tried different values of N and found that tuning still fails for N=2, 4, 8, 12, 17, 18, 24, but succeeds for N=16 and 32. I suspect this is due to the alignment requirement of the m16n16k16 Tensor Core intrinsic.
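For reference, the same dense + add workload can be generated for an arbitrary batch size N with TE roughly as follows. This is my own sketch for sweeping N, not part of the original script; it uses te.create_prim_func and plain int shapes, whereas the modules in the reproducer below use int64 shapes and set layout_free_buffers explicitly.
import tvm
from tvm import te

def dense_add_fp16(n: int, m: int = 1000, k: int = 512) -> tvm.IRModule:
    # Build dense (NT matmul) + elementwise add in fp16 for a given batch size n.
    p0 = te.placeholder((n, k), dtype="float16", name="p0")
    p1 = te.placeholder((m, k), dtype="float16", name="p1")
    p2 = te.placeholder((n, m), dtype="float16", name="p2")
    kk = te.reduce_axis((0, k), name="k")
    matmul = te.compute(
        (n, m),
        lambda i, j: te.sum(p0[i, kk] * p1[j, kk], axis=kk),
        name="T_matmul_NT",
    )
    t_add = te.compute((n, m), lambda i, j: matmul[i, j] + p2[i, j], name="T_add")
    return tvm.IRModule({"main": te.create_prim_func([p0, p1, p2, t_add])})
Passing, e.g., dense_add_fp16(8) as mod to tune_tir in the reproducer below should exercise the same failing shapes.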
Expected behavior
The tuning should succeed.
Environment
- OS: Ubuntu 20.04.3 LTS
- TVM version: main branch
- GPU: nvidia-a100
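As a quick sanity check of the environment (not part of the original report), the device's compute capability and Tensor Core support can be queried with the same nvcc helper the reproducer uses:
import tvm
from tvm.contrib import nvcc

dev = tvm.cuda(0)
print("compute version:", dev.compute_version)  # expected to be "8.0" on an A100
print("tensor cores:", nvcc.have_tensorcore(dev.compute_version))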
Steps to reproduce
import tempfile
import os
import numpy as np
import tvm
import tvm.tir.tensor_intrin
from tvm import meta_schedule as ms
from tvm.meta_schedule import tune_tir
from tvm.meta_schedule.database import JSONDatabase
from tvm.target import Target
from tvm.tir import Schedule
from tvm.ir.transform import PassContext
from tvm.meta_schedule.testing import te_workload
from tvm import tir
from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class Module0:
    @T.prim_func
    def main(p0: T.Buffer((T.int64(1), T.int64(512)), "float16"), p1: T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(1), T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(1), T.int64(1000)), "float16")):
        T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True})
        # with T.block("root"):
        T_matmul_NT = T.alloc_buffer((T.int64(1), T.int64(1000)), "float16")
        for i, j, k in T.grid(T.int64(1), T.int64(1000), T.int64(512)):
            with T.block("T_matmul_NT"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(p0[v_i, v_k], p1[v_j, v_k])
                T.writes(T_matmul_NT[v_i, v_j])
                with T.init():
                    T_matmul_NT[v_i, v_j] = T.float16(0)
                T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k] * p1[v_j, v_k]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]


@I.ir_module
class Module1:
    @T.prim_func
    def main(p0: T.Buffer((T.int64(16), T.int64(512)), "float16"), p1: T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(16), T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(16), T.int64(1000)), "float16")):
        T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True})
        # with T.block("root"):
        T_matmul_NT = T.alloc_buffer((T.int64(16), T.int64(1000)), "float16")
        for i, j, k in T.grid(T.int64(16), T.int64(1000), T.int64(512)):
            with T.block("T_matmul_NT"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(p0[v_i, v_k], p1[v_j, v_k])
                T.writes(T_matmul_NT[v_i, v_j])
                with T.init():
                    T_matmul_NT[v_i, v_j] = T.float16(0)
                T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k] * p1[v_j, v_k]
        for ax0, ax1 in T.grid(T.int64(16), T.int64(1000)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]


def tune_dense_add_cuda_tensor_core():
    target = Target("nvidia/nvidia-a100")
    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.database.Database.create(kind="json", work_dir=work_dir)
        mod = Module0  # failed
        # mod = Module1  # success
        database = tune_tir(
            mod=mod,
            target=target,
            work_dir=work_dir,
            num_trials_per_iter=10,
            max_trials_global=10,
            strategy="replay-trace",
            # strategy="evolutionary",
            database=database,
        )
        sch = ms.tir_integration.compile_tir(database, mod, target)
        if sch is None:
            print("No valid schedule found!")
        else:
            from tvm.contrib import nvcc
            import numpy as np

            ctx = tvm.cuda()
            if nvcc.have_tensorcore(ctx.compute_version):
                with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
                    func = tvm.build(sch.mod["main"], [], "cuda")
                    # print(func.imported_modules[0].get_source())
                    # print(sch.mod.script())
                    # print(sch.trace)


if __name__ == "__main__":
    tune_dense_add_cuda_tensor_core()

Triage
- tune:meta_schedule
cc @ibsidorenko
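A possible interim workaround, purely my own assumption based on the observation above that N=16 and N=32 tune successfully (untested here): pad the batch dimension up to a multiple of 16 before tuning and slice the result back afterwards, e.g.:
import numpy as np

def pad_batch(x: np.ndarray, multiple: int = 16) -> np.ndarray:
    # Zero-pad the first (batch) dimension up to the next multiple of `multiple`.
    n = x.shape[0]
    n_pad = -(-n // multiple) * multiple  # ceiling to the next multiple
    if n_pad == n:
        return x
    out = np.zeros((n_pad,) + x.shape[1:], dtype=x.dtype)
    out[:n] = x
    return out

# Tune, e.g., dense_add_fp16(16) (see the sketch above) instead of the N=1 module,
# feed pad_batch(p0) and pad_batch(p2) at runtime, and take result[:1] to recover
# the N=1 output.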