Merged
@@ -556,6 +556,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
     Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
     Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
     bool use_software_pipeline) {
+  if (tile_binds.defined()) {
+    for (const String& tile_bind : tile_binds.value()) {
+      CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core.";
+    }
+  }
Comment on lines +559 to +563

Contributor:
Does this check (along with the existing tests) now cover the regression this fixes? If not, adding a test to protect against the regression in the future will surely save you and others time!

csullivan (Contributor), Oct 31, 2022:
Oops, looks like @junrushao already asked my question. I trust your judgment :)

Member Author:
Yes, this check covers the broken case that incorrectly used threadIdx.x. As a follow-up, I'm also thinking of adding some end-to-end search-space generation tests for the default config of each target.

  auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);

2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -139,7 +139,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
   Array<ScheduleRule> results{ScheduleRule::MultiLevelTilingTensorCore(
       /*intrin_groups=*/intrin_groups,
       /*structure=*/"SSSRRSRS",
-      /*tile_binds=*/Array<String>{"blockIdx.x", "vthread.x", "threadIdx.x"},
+      /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
       /*max_innermost_factor=*/Integer(4),
       /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
       /*reuse_read=*/