From d41bb0b708ab16a9ea0d27bd7cc3d453d73460f0 Mon Sep 17 00:00:00 2001
From: Rainy-Memory <2630737606@qq.com>
Date: Sat, 29 Jul 2023 05:59:34 +0000
Subject: [PATCH] fix

---
 .../multi_level_tiling_tensor_core.cc         |  5 +++-
 .../schedule_rule/schedule_rule.cc            | 23 ++++++++++++++++---
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 18bd58510d85..d519187d303f 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -336,6 +336,10 @@ std::vector<State> MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta
   const BlockRV& block_rv = state->block_rv;
   // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
   Array<LoopRV> loops = sch->GetLoops(block_rv);
+  if (!(loops.size() == 3 || !state->is_mma)) {
+    LOG(DEBUG) << "The MMA tensor core only supports SSR loops now";
+    return {};
+  }
   std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv));
   ICHECK_EQ(loops.size(), iter_types.size());
   // Step 2. For each loop axis, tile it
@@ -344,7 +348,6 @@ std::vector<State> MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta
   state->tile_factors.resize(tiles.size());
   std::vector<Array<tir::ExprRV>> tile_factors;
   tile_factors.resize(tiles.size());
-  ICHECK(loops.size() == 3 || !state->is_mma) << "The MMA tensor core only supports SSR loops now";
   for (int i = 0, n = loops.size(); i < n; ++i) {
     LoopRV loop = loops[i];
     const std::vector<int>* idx = nullptr;
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 2a5efcd76033..3be264332461 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -171,7 +171,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
 }
 
 Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
-  Array<Map<String, String>> intrin_groups = {
+  Array<Map<String, String>> wmma_intrin_groups = {
       // Tensor Cores f32 += f16 * f16
       {
           {"init", "wmma_fill_16x16x16_f32"},
@@ -217,6 +217,8 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
           {"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
           {"store", "wmma_store_16x16x16_s32_shared_dyn"},
       },
+  };
+  Array<Map<String, String>> mma_intrin_groups = {
       // Tensor Core MMA
       {
           {"init", "mma_init_m16n8k8_f16"},
@@ -236,7 +238,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
   Array<ScheduleRule> results{
       ScheduleRule::ApplyCustomRule(),
       ScheduleRule::MultiLevelTilingTensorCore(
-          /*intrin_groups=*/intrin_groups,
+          /*intrin_groups=*/wmma_intrin_groups,
           /*structure=*/"SSSRRSRS",
           /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
           /*max_innermost_factor=*/Integer(4),
@@ -249,7 +251,22 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
           Map<String, ObjectRef>{{"req", String("must")},
                                  {"levels", Array<Integer>{2}},  //
                                  {"scope", String("shared.dyn")}},
-          /*use_software_pipeline=*/false)  //
+          /*use_software_pipeline=*/false),  //
+      ScheduleRule::MultiLevelTilingTensorCore(
+          /*intrin_groups=*/mma_intrin_groups,
+          /*structure=*/"SSSRRSRS",
+          /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
+          /*max_innermost_factor=*/Integer(4),
+          /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
+          /*reuse_read=*/
+          Map<String, ObjectRef>{{"req", String("must")},
+                                 {"levels", Array<Integer>{4}},  //
+                                 {"scope", String("shared.dyn")}},
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("no")},
+                                 {"levels", Array<Integer>{2}},  //
+                                 {"scope", String("shared.dyn")}},
+          /*use_software_pipeline=*/true)  //
   };
   Array<ScheduleRule> append = ScheduleRule::DefaultCUDA();
   results.insert(results.end(), append.begin() + 1, append.end());