From 4d1292910e44152c5a72fbc122e98fd757237351 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 31 Oct 2022 11:34:25 -0700
Subject: [PATCH] [MetaSchedule] Fix thread bindings of MultiLevelTilingTensorCore

---
 .../schedule_rule/multi_level_tiling_tensor_core.cc | 5 +++++
 src/meta_schedule/schedule_rule/schedule_rule.cc | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index e8a03c722656..37c35248329a 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -556,6 +556,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
     Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
     Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
     bool use_software_pipeline) {
+  if (tile_binds.defined()) {
+    for (const String& tile_bind : tile_binds.value()) {
+      CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core.";
+    }
+  }
   auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
       structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
 
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index bd492d03eac6..8e4642b50ddb 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -139,7 +139,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
   Array<ScheduleRule> results{ScheduleRule::MultiLevelTilingTensorCore(
       /*intrin_groups=*/intrin_groups,
       /*structure=*/"SSSRRSRS",
-      /*tile_binds=*/Array<String>{"blockIdx.x", "vthread.x", "threadIdx.x"},
+      /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
       /*max_innermost_factor=*/Integer(4),
       /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
       /*reuse_read=*/