@@ -336,6 +336,10 @@ std::vector<State> MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta
   const BlockRV& block_rv = state->block_rv;
   // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
   Array<LoopRV> loops = sch->GetLoops(block_rv);
+  if (!(loops.size() == 3 || !state->is_mma)) {
+    LOG(DEBUG) << "The MMA tensor core only supports SSR loops now";
+    return {};
+  }
   std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv));
   ICHECK_EQ(loops.size(), iter_types.size());
   // Step 2. For each loop axis, tile it
@@ -344,7 +348,6 @@ std::vector<State> MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta
   state->tile_factors.resize(tiles.size());
   std::vector<Array<tir::ExprRV>> tile_factors;
   tile_factors.resize(tiles.size());
-  ICHECK(loops.size() == 3 || !state->is_mma) << "The MMA tensor core only supports SSR loops now";
   for (int i = 0, n = loops.size(); i < n; ++i) {
     LoopRV loop = loops[i];
     const std::vector<int>* idx = nullptr;
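
Taken together, these two hunks move the loop-shape check earlier and soften it: instead of an ICHECK that aborts the whole run when the MMA path meets a loop nest that is not exactly three loops (S, S, R), the rule now logs at DEBUG level and returns no candidate states, leaving the block to other schedule rules. A minimal sketch of this skip-instead-of-abort pattern, using hypothetical stand-in types rather than the real tvm::tir classes:

#include <cstdio>
#include <vector>

struct LoopRV {};  // hypothetical stand-in for tvm::tir::LoopRV

// Sketch: decline gracefully when the loop nest cannot be MMA-tiled.
// An empty result means "no candidates from this rule"; the old ICHECK
// would instead have terminated the process.
std::vector<int> MMATileLoopNestSketch(const std::vector<LoopRV>& loops, bool is_mma) {
  if (!(loops.size() == 3 || !is_mma)) {  // MMA path requires an S, S, R nest
    std::fprintf(stderr, "DEBUG: The MMA tensor core only supports SSR loops now\n");
    return {};
  }
  // ... per-axis tiling would proceed here ...
  return {0};  // placeholder for the produced states
}
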
src/meta_schedule/schedule_rule/schedule_rule.cc (23 changes: 20 additions & 3 deletions)

@@ -171,7 +171,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
 }
 
 Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
-  Array<Map<String, String>> intrin_groups = {
+  Array<Map<String, String>> wmma_intrin_groups = {
       // Tensor Cores f32 += f16 * f16
       {
           {"init", "wmma_fill_16x16x16_f32"},
@@ -217,6 +217,8 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
           {"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
           {"store", "wmma_store_16x16x16_s32_shared_dyn"},
       },
+  };
+  Array<Map<String, String>> mma_intrin_groups = {
       // Tensor Core MMA
       {
           {"init", "mma_init_m16n8k8_f16"},
@@ -236,7 +238,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
   Array<ScheduleRule> results{
       ScheduleRule::ApplyCustomRule(),
       ScheduleRule::MultiLevelTilingTensorCore(
-          /*intrin_groups=*/intrin_groups,
+          /*intrin_groups=*/wmma_intrin_groups,
           /*structure=*/"SSSRRSRS",
           /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
           /*max_innermost_factor=*/Integer(4),
@@ -249,7 +251,22 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
           Map<String, ObjectRef>{{"req", String("must")},
                                  {"levels", Array<Integer>{2}}, //
                                  {"scope", String("shared.dyn")}},
-          /*use_software_pipeline=*/false) //
+          /*use_software_pipeline=*/false), //
+      ScheduleRule::MultiLevelTilingTensorCore(
+          /*intrin_groups=*/mma_intrin_groups,
+          /*structure=*/"SSSRRSRS",
+          /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
+          /*max_innermost_factor=*/Integer(4),
+          /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
+          /*reuse_read=*/
+          Map<String, ObjectRef>{{"req", String("must")},
+                                 {"levels", Array<Integer>{4}}, //
+                                 {"scope", String("shared.dyn")}},
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("no")},
+                                 {"levels", Array<Integer>{2}}, //
+                                 {"scope", String("shared.dyn")}},
+          /*use_software_pipeline=*/true) //
   };
   Array<ScheduleRule> append = ScheduleRule::DefaultCUDA();
   results.insert(results.end(), append.begin() + 1, append.end());
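
With this hunk, DefaultCUDATensorCore() registers two MultiLevelTilingTensorCore rules back to back: the existing WMMA rule (write reuse required at level 2, software pipelining off) and the new MMA rule (read reuse required at level 4, no write reuse, software pipelining on). Since a rule may decline a block, as the early return in MMATileLoopNest above shows, the design space simply collects whichever variants applied. A rough sketch of that fan-out with hypothetical stand-ins (not TVM's actual search driver):

#include <vector>

struct State {};  // hypothetical stand-in for a partially scheduled candidate

// A rule maps one candidate to zero or more tiled candidates; an empty
// result means the rule chose not to apply (cf. the MMA SSR guard above).
using RuleFn = std::vector<State> (*)(const State&);

std::vector<State> FanOut(const std::vector<RuleFn>& rules, const State& init) {
  std::vector<State> out;
  for (RuleFn rule : rules) {
    std::vector<State> produced = rule(init);
    out.insert(out.end(), produced.begin(), produced.end());
  }
  return out;  // candidates from every rule that applied
}
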