diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
index 08a8248dfdbc..4095d6ca0397 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -131,6 +131,8 @@ class Mutator : public runtime::ObjectRef {
                                    FApply f_apply, FClone f_clone, FAsString f_as_string);
   /*! \brief Create default mutators for LLVM */
   TVM_DLL static Map DefaultLLVM();
+  /*! \brief Create default mutators for x86 VNNI */
+  TVM_DLL static Map DefaultVNNI();
   /*! \brief Create default mutators for CUDA */
   TVM_DLL static Map DefaultCUDA();
   /*! \brief Create default mutators for CUDA with TensorCore */
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index a680a647956c..13fe47058740 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -152,6 +152,8 @@ class Postproc : public runtime::ObjectRef {
   TVM_DLL static Postproc RewriteLayout();
   /*! \brief Create default postprocessors for LLVM */
   TVM_DLL static Array DefaultLLVM();
+  /*! \brief Create default postprocessors for x86 VNNI */
+  TVM_DLL static Array DefaultVNNI();
   /*! \brief Create default postprocessors for CUDA */
   TVM_DLL static Array DefaultCUDA();
   /*! \brief Create default postprocessors for CUDA with TensorCore */
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 70dec47e60bd..a3d6c7ef68bf 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -285,6 +285,8 @@ class ScheduleRule : public runtime::ObjectRef {
 
   /*! \brief Create default schedule rules for LLVM */
   TVM_DLL static Array DefaultLLVM();
+  /*! \brief Create default schedule rules for x86 VNNI */
+  TVM_DLL static Array DefaultVNNI();
   /*! \brief Create default schedule rules for CUDA */
   TVM_DLL static Array DefaultCUDA();
   /*! \brief Create default postprocessors for CUDA with TensorCore */
diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc
index 8e9bfc8bde4b..8f3d14b6c466 100644
--- a/src/meta_schedule/mutator/mutator.cc
+++ b/src/meta_schedule/mutator/mutator.cc
@@ -59,6 +59,9 @@ Map Mutator::DefaultLLVM() {
       {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}};
 }
 
+// The VNNI defaults simply reuse the LLVM mutator set.
+Map Mutator::DefaultVNNI() { return Mutator::DefaultLLVM(); }
+
 Map Mutator::DefaultCUDA() {
   return Map{
       {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
index 0738c871120f..c614f3230d59 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -59,6 +59,14 @@ Array Postproc::DefaultLLVM() {
   };
 }
 
+Array Postproc::DefaultVNNI() {
+  return Array{
+      Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
+      Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/true),
+      Postproc::RewriteLayout(),
+  };
+}
+
 Array Postproc::DefaultCUDA() {
   return Array{
       Postproc::DisallowDynamicLoop(),
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index b1e8c3695d3e..e4f97c1fa673 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -85,6 +85,51 @@ Array ScheduleRule::DefaultLLVM() {
   };
 }
 
+Array ScheduleRule::DefaultVNNI() {
+  return {
+      ScheduleRule::ApplyCustomRule(),
+      ScheduleRule::InlineConstantScalars(),
+      ScheduleRule::AutoInline(
+          /*into_producer=*/false,
+          /*into_consumer=*/true,
+          /*inline_const_tensor=*/true,
+          /*disallow_if_then_else=*/true,
+          /*require_injective=*/true,
+          /*require_ordered=*/true,
+          /*disallow_op=*/Array{"tir.exp"}),
+      ScheduleRule::AddRFactor(
+          /*max_jobs_per_core=*/16,
+          /*max_innermost_factor=*/Integer(64)),
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/"dot_16x4_vnni",
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(64),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map{{"req", String("may")},
+              {"levels", Array{1, 2}},
+              {"scope", String("global")}}),
+      ScheduleRule::MultiLevelTiling(
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(64),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map{{"req", String("may")},
+              {"levels", Array{1, 2}},
+              {"scope", String("global")}}),
+      ScheduleRule::ParallelizeVectorizeUnroll(
+          /*max_jobs_per_core=*/16,
+          /*max_vectorize_extent=*/64,
+          /*unroll_max_steps=*/Array{0, 16, 64, 512},
+          /*unroll_explicit=*/true),
+      ScheduleRule::RandomComputeLocation(),
+  };
+}
+
 Array ScheduleRule::DefaultCUDA() {
   return {
       ScheduleRule::ApplyCustomRule(),
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index bcc0673e5924..bd124511b83c 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -23,6 +23,14 @@ namespace meta_schedule {
 
 String GetRuleKindFromTarget(const Target& target) {
   if (target->kind->name == "llvm") {
+    static const PackedFunc* f_check_vnni =
+        runtime::Registry::Get("tvm.topi.x86.utils.target_has_vnni");
+    // Validate the looked-up pointer itself before it is dereferenced in the call below.
+    ICHECK(f_check_vnni != nullptr) << "The `target_has_vnni` func is not in tvm registry.";
+    if (target->GetAttr("mcpu") &&
+        (*f_check_vnni)(target->GetAttr("mcpu").value())) {
+      return "vnni";
+    }
     return "llvm";
   }
   if (target->kind->name == "hexagon") {
@@ -79,6 +87,10 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
       default_sch_rules = ScheduleRule::DefaultHexagon();
       default_postprocs = Postproc::DefaultHexagon();
       default_mutator_probs = Mutator::DefaultHexagon();
+    } else if (kind == "vnni") {
+      default_sch_rules = ScheduleRule::DefaultVNNI();
+      default_postprocs = Postproc::DefaultVNNI();
+      default_mutator_probs = Mutator::DefaultVNNI();
     } else {
       LOG(FATAL) << "Unsupported kind: " << kind;
       throw;