Skip to content

[Bug][MetaSchedule] Failed to tune fp16 dense_add workload of some shapes on cuda #14137

@wllqwzx

Description

@wllqwzx

I found that when tuning the fp16 tensorcore dense_add kernel, the tuning fails on some shapes and the reported error is non-deterministic.

For example, when the workload is N=1, M=1000, K=512, the tuning fails.

There are two kinds of reported errors. From my observation, the following error may be reported more frequently:

Click me
2023-02-27 14:11:46 [INFO] Logging directory: /tmp/tmp71o3_ldv/logs
2023-02-27 14:11:46 [INFO] LocalBuilder: max_workers = 11
2023-02-27 14:11:47 [INFO] LocalRunner: max_workers = 1
2023-02-27 14:11:48 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"
2023-02-27 14:11:48 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
Traceback (most recent call last):
  File "bug_tune_dense_add.py", line 507, in <module>
    test_tune_tir_matmul_cuda_tensor_core()
  File "bug_tune_dense_add.py", line 195, in test_tune_tir_matmul_cuda_tensor_core
    database = tune_tir(
  File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py", line 104, in tune_tir
    return tune_tasks(
  File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py", line 117, in tune_tasks
    task_scheduler.tune(
  File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune
    _ffi_api.TaskSchedulerTune(  # type: ignore # pylint: disable=no-member
  File "/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  6: TVMFuncCall
  5: _ZN3tvm7runtime13PackedF
  4: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, 
tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  3: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
  2: tvm::meta_schedule::ReplayTraceNode::GenerateMeasureCandidates()
  1: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates()
  0: tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)
  File "/mnt/disk5/wll/code/metaschedule/src/support/parallel_for.cc", line 128
RuntimeError: parallel_for_dynamic error with ScheduleError: (not rendered)

The following error is reported less frequently:

Click me
2023-02-27 14:20:13 [INFO] Logging directory: /tmp/tmputfxvrl5/logs
2023-02-27 14:20:13 [INFO] LocalBuilder: max_workers = 11
2023-02-27 14:20:14 [INFO] LocalRunner: max_workers = 1
2023-02-27 14:20:15 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"
Traceback (most recent call last):
  File "bug_tune_dense_add.py", line 507, in <module>
    test_tune_tir_matmul_cuda_tensor_core()
  File "bug_tune_dense_add.py", line 195, in test_tune_tir_matmul_cuda_tensor_core
    database = tune_tir(
  File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py", line 104, in tune_tir
    return tune_tasks(
  File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py", line 117, in tune_tasks
    task_scheduler.tune(
  File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune
    _ffi_api.TaskSchedulerTune(  # type: ignore # pylint: disable=no-member
  File "/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
  9: TVMFuncCall
  8: _ZN3tvm7runtime13PackedF
  7: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, 
tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  6: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
  5: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
  4: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
  3: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<tvm::meta_schedule::State, std::allocator<tvm::meta_schedule::State> >)
  2: tvm::meta_schedule::MultiLevelTilingNode::AddReadReuse(tvm::meta_schedule::State) const
  1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
  0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
ScheduleError: An error occurred in the schedule primitive 'compute-at'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(p0_handle: T.handle, p1_handle: T.handle, p2_handle: T.handle, T_add_handle: T.handle):
        T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True})
        p0 = T.match_buffer(p0_handle, (T.int64(8), T.int64(512)), "float16")
        p1 = T.match_buffer(p1_handle, (T.int64(1000), T.int64(512)), "float16")
        p2 = T.match_buffer(p2_handle, (T.int64(8), T.int64(1000)), "float16")
        T_add = T.match_buffer(T_add_handle, (T.int64(8), T.int64(1000)), "float16")
        # tir.Block#0
        with T.block("root"):
        ^^^^^^^^^^^^^^^^^^^^^
            T.reads()
            ^^^^^^^^^
            T.writes()
            ^^^^^^^^^^
            T_matmul_NT = T.alloc_buffer((T.int64(8), T.int64(1000)), "float16")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            p0_reindex = T.alloc_buffer((T.int64(16), T.int64(512)), "float16")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            p1_reindex = T.alloc_buffer((T.int64(1008), T.int64(512)), "float16")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            T_matmul_NT_reindex_shared_dyn = T.alloc_buffer((T.int64(16), T.int64(1008)), "float16", scope="shared.dyn")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            T_matmul_NT_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((T.int64(16), T.int64(1008)), "float16", scope="wmma.accumulator")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            p0_reindex_shared_dyn = T.alloc_buffer((T.int64(16), T.int64(512)), "float16", scope="shared.dyn")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0 in range(T.int64(16)):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax1 in range(T.int64(512)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    with T.block("p0_reindex_reindex"):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        v0 = T.axis.spatial(T.int64(16), ax0)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        v1 = T.axis.spatial(T.int64(512), ax1)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        T.reads(p0[v0, v1])
                        ^^^^^^^^^^^^^^^^^^^
                        T.writes(p0_reindex[v0, v1])
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        p0_reindex[v0, v1] = T.if_then_else(v0 < T.int64(8), p0[v0, v1], T.float16(0))
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0 in range(T.int64(1008)):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax1 in range(T.int64(512)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    with T.block("p1_reindex_reindex"):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        v0 = T.axis.spatial(T.int64(1008), ax0)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        v1 = T.axis.spatial(T.int64(512), ax1)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        T.reads(p1[v0, v1])
                        ^^^^^^^^^^^^^^^^^^^
                        T.writes(p1_reindex[v0, v1])
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        p1_reindex[v0, v1] = T.if_then_else(v0 < T.int64(1000), p1[v0, v1], T.float16(0))
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0 in range(T.int64(16)):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax1 in range(T.int64(512)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    with T.block("p0_reindex_shared.dyn"):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        v0 = T.axis.spatial(T.int64(16), ax0)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        v1 = T.axis.spatial(T.int64(512), ax1)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        T.reads(p0_reindex[v0, v1])
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        T.writes(p0_reindex_shared_dyn[v0, v1])
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        p0_reindex_shared_dyn[v0, v1] = p0_reindex[v0, v1]
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0_0_0_ax1_0_0_fused in T.thread_binding(T.int64(1), thread="blockIdx.y"):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax0_0_1_ax1_0_1_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    for ax0_0_2_ax1_0_2_fused in T.thread_binding(T.int64(3), thread="threadIdx.y"):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for ax2_0_0 in range(T.int64(1)):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            for ax2_0_1 in range(T.int64(32)):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                for ax0_0_3 in range(T.int64(1)):
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    for ax1_0_3 in range(T.int64(21)):
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        for ax2_0_2 in range(T.int64(1)):
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                            for ax0_0_4 in range(T.int64(1)):
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                for ax1_0_4 in range(T.int64(1)):
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                    with T.block("T_matmul_NT_o"):
                                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                        v0_o = T.axis.spatial(T.int64(1), ax0_0_4 + ax0_0_3)
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                        v1_o = T.axis.spatial(T.int64(63), ax1_0_4 + ax0_0_2_ax1_0_2_fused * T.int64(21) + ax1_0_3)
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                        v2_o = T.axis.reduce(T.int64(32), ax2_0_0 * T.int64(32) + ax2_0_1 + ax2_0_2)
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                        T.reads(p0_reindex_shared_dyn[T.int64(0):T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], p1_reindex[v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)])
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                        T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                        T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f16_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f16", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "warp_execution": 1})
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                        with T.init():
                                                        ^^^^^^^^^^^^^^
                                                            for ax0_1 in range(T.int64(16)):
                                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                for ax1_1 in range(T.int64(16)):
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                    with T.block("T_matmul_NT_init"):
                                                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        v0_i_init = T.axis.spatial(T.int64(16), ax0_1)
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        v1_i_init = T.axis.spatial(T.int64(16), ax1_1)
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        T.reads()
                                                                        ^^^^^^^^^
                                                                        T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * T.int64(16) + v1_i_init])
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * T.int64(16) + v1_i_init] = T.float16(0)
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                        for ax0_1 in range(T.int64(16)):
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                            for ax1_1 in range(T.int64(16)):
                                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                for ax2_1 in range(T.int64(16)):
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                    with T.block("T_matmul_NT"):
                                                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        v0_i = T.axis.spatial(T.int64(16), ax0_1)
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        v1_i = T.axis.spatial(T.int64(16), ax1_1)
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        v2_i = T.axis.reduce(T.int64(16), ax2_1)
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i], p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i], p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i])
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i])
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                                        T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i] = T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i] + p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i] * p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i]
                                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for ax0_0 in range(T.int64(1)):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            for ax1_0 in range(T.int64(21)):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                with T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator_o"):
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    v0_o = T.axis.spatial(T.int64(1), ax0_0)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    v1_o = T.axis.spatial(T.int64(63), ax0_0_2_ax1_0_2_fused * T.int64(21) + ax1_0)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    T.writes(T_matmul_NT_reindex_shared_dyn[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)])
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f16_shared_dyn"})
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    for ax0_1 in range(T.int64(16)):
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        for ax1_1 in range(T.int64(16)):
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                            with T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator"):
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                v0_i = T.axis.spatial(T.int64(16), ax0_1)
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                v1_i = T.axis.spatial(T.int64(16), ax1_1)
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i])
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                T.writes(T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i])
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i] = T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i]
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    for ax0_ax1_fused in range(T.int64(16128)):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        with T.block("T_matmul_NT_reindex_shared.dyn"):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v0 = T.axis.spatial(T.int64(16), ax0_ax1_fused // T.int64(1008))
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v1 = T.axis.spatial(T.int64(1008), ax0_ax1_fused % T.int64(1008))
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.where(ax0_ax1_fused < T.int64(8056) and ax0_ax1_fused % T.int64(1008) < T.int64(1000))
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.reads(T_matmul_NT_reindex_shared_dyn[v0, v1])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.writes(T_matmul_NT[v0, v1])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.block_attr({"meta_schedule.cooperative_fetch": 1})
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T_matmul_NT[v0, v1] = T_matmul_NT_reindex_shared_dyn[v0, v1]
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0 in range(T.int64(8)):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax1 in range(T.int64(1000)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    with T.block("T_add"):
                    ^^^^^^^^^^^^^^^^^^^^^^
                        v_ax0 = T.axis.spatial(T.int64(8), ax0)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        v_ax1 = T.axis.spatial(T.int64(1000), ax1)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        T.writes(T_add[v_ax0, v_ax1])
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]

                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The scope tir.Block#0 is not a stage pipeline.
Definition of a scope that is a stage pipeline:
- The region cover property holds for every of its child blocks
- No write-after-read dependency or opaque dependency,
- only read-after-write and write-after-write are allowed
- All the statements in the scope are schedulable statements, i.e. Block and For

I tried different values of N and found that the tuning still fails when N = 2, 4, 8, 12, 17, 18, or 24, but succeeds when N = 16 or 32. I suspect this is related to the alignment requirement of the m16n16k16 tensor core intrinsic.

Expected behavior

The tuning succeeds for all of these shapes.

Environment

  • OS: Ubuntu 20.04.3 LTS
  • TVM version: main branch
  • GPU: nvidia-a100

Steps to reproduce

import tempfile
import os
import numpy as np

import tvm
import tvm.tir.tensor_intrin
from tvm import meta_schedule as ms
from tvm.meta_schedule import tune_tir
from tvm.meta_schedule.database import JSONDatabase
from tvm.target import Target
from tvm.tir import Schedule
from tvm.ir.transform import PassContext
from tvm.meta_schedule.testing import te_workload
from tvm import tir
from tvm.script import ir as I
from tvm.script import tir as T


# Workload with a single-row input: T_add = matmul_NT(p0, p1) + p2, all float16.
# p0 is (1, 512), p1 is (1000, 512), and p2 / T_add are (1, 1000).
# This is the module selected (and marked "# failed") by the tuning driver in this script.
@I.ir_module
class Module0:
    @T.prim_func
    def main(p0: T.Buffer((T.int64(1), T.int64(512)), "float16"), p1: T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(1), T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(1), T.int64(1000)), "float16")):
        T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True})
        # with T.block("root"):
        # Intermediate buffer holding the matmul result before p2 is added.
        T_matmul_NT = T.alloc_buffer((T.int64(1), T.int64(1000)), "float16")
        # T_matmul_NT[i, j] = sum over k of p0[i, k] * p1[j, k]
        # (p1 is indexed as [j, k], i.e. it is used transposed).
        for i, j, k in T.grid(T.int64(1), T.int64(1000), T.int64(512)):
            with T.block("T_matmul_NT"):
                # i and j are spatial axes; k is the reduction axis.
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(p0[v_i, v_k], p1[v_j, v_k])
                T.writes(T_matmul_NT[v_i, v_j])
                # Zero the accumulator before the reduction over k starts.
                with T.init():
                    T_matmul_NT[v_i, v_j] = T.float16(0)
                T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k] * p1[v_j, v_k]
        # Elementwise add of p2 onto the matmul result, written to the output buffer.
        for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]


# Same computation as Module0 but with 16 input rows: T_add = matmul_NT(p0, p1) + p2, all float16.
# p0 is (16, 512), p1 is (1000, 512), and p2 / T_add are (16, 1000).
# This is the alternative module marked "# success" in the tuning driver in this script.
@I.ir_module
class Module1:
    @T.prim_func
    def main(p0: T.Buffer((T.int64(16), T.int64(512)), "float16"), p1: T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(16), T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(16), T.int64(1000)), "float16")):
        T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True})
        # with T.block("root"):
        # Intermediate buffer holding the matmul result before p2 is added.
        T_matmul_NT = T.alloc_buffer((T.int64(16), T.int64(1000)), "float16")
        # T_matmul_NT[i, j] = sum over k of p0[i, k] * p1[j, k]
        # (p1 is indexed as [j, k], i.e. it is used transposed).
        for i, j, k in T.grid(T.int64(16), T.int64(1000), T.int64(512)):
            with T.block("T_matmul_NT"):
                # i and j are spatial axes; k is the reduction axis.
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(p0[v_i, v_k], p1[v_j, v_k])
                T.writes(T_matmul_NT[v_i, v_j])
                # Zero the accumulator before the reduction over k starts.
                with T.init():
                    T_matmul_NT[v_i, v_j] = T.float16(0)
                T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k] * p1[v_j, v_k]
        # Elementwise add of p2 onto the matmul result, written to the output buffer.
        for ax0, ax1 in T.grid(T.int64(16), T.int64(1000)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]


def tune_dense_add_cuda_tensor_core():
    """Tune the fp16 dense+add workload with MetaSchedule, then build the result.

    Tunes ``Module0`` (the single-row workload marked as failing; switch to
    ``Module1`` for the 16-row workload marked as succeeding) for the
    ``nvidia/nvidia-a100`` target, keeping the JSON tuning database in a
    temporary working directory that is removed on exit.

    If the database yields a schedule and the local CUDA device reports tensor
    core support, the scheduled ``main`` function is built for CUDA with
    ``tir.use_async_copy`` enabled. Otherwise a message is printed (no
    schedule) or nothing further happens (no tensor cores).

    Returns:
        None.
    """
    target = Target("nvidia/nvidia-a100")
    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.database.Database.create(kind="json", work_dir=work_dir)
        mod = Module0   # failed
        # mod = Module1   # success
        # tune_tir returns the database holding the tuning records.
        database = tune_tir(
            mod=mod,
            target=target,
            work_dir=work_dir,
            num_trials_per_iter=10,
            max_trials_global=10,
            strategy="replay-trace",
            # strategy="evolutionary",
            database=database,
        )
        sch = ms.tir_integration.compile_tir(database, mod, target)
        if sch is None:
            print("No valid schedule found!")
            return

        # Imported lazily so the nvcc helper is only needed once a schedule exists.
        from tvm.contrib import nvcc

        ctx = tvm.cuda()
        if not nvcc.have_tensorcore(ctx.compute_version):
            # The device has no tensor cores, so skip the build.
            return

        with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
            # `func` is kept so the debugging prints below can be re-enabled.
            func = tvm.build(sch.mod["main"], [], "cuda")
            # print(func.imported_modules[0].get_source())
            # print(sch.mod.script())
            # print(sch.trace)


# Run the reproduction when this file is executed as a script.
if __name__ == "__main__":
    tune_dense_add_cuda_tensor_core()

Triage

  • tune:meta_schedule

cc @ibsidorenko

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address it · type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions