From 223f6ffaf1aa9016e89615eca07b713a82e5ea81 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 13 Oct 2022 04:47:49 +0900 Subject: [PATCH 01/28] Introduce new module equality to extract only anchor block tasks --- src/meta_schedule/default_schedule.cc | 149 ++++++++++++++++++ src/meta_schedule/default_schedule.h | 37 +++++ src/meta_schedule/module_equality.cc | 47 ++++++ .../schedule_rule/schedule_rule.cc | 19 +-- .../space_generator/space_generator.cc | 8 +- src/relay/backend/task_extraction.cc | 67 ++++++-- src/relay/backend/te_compiler_cache.cc | 12 +- src/target/target_kind.cc | 3 + .../schedule/primitive/cache_read_write.cc | 12 +- src/tir/schedule/trace.cc | 16 +- .../test_meta_schedule_relay_integration.py | 4 + 11 files changed, 337 insertions(+), 37 deletions(-) create mode 100644 src/meta_schedule/default_schedule.cc create mode 100644 src/meta_schedule/default_schedule.h diff --git a/src/meta_schedule/default_schedule.cc b/src/meta_schedule/default_schedule.cc new file mode 100644 index 000000000000..2519c3e3f702 --- /dev/null +++ b/src/meta_schedule/default_schedule.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "default_schedule.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../printer/text_printer.h" +#include "../tir/schedule/analysis.h" + +namespace tvm { +namespace meta_schedule { + +static const std::unordered_set gpu_targets{"cuda", "rocm", "vulkan", "metal"}; + +ScheduleRule GetDefaultAutoInline(const std::string& target_name) { + if (target_name == "llvm" || target_name == "hexagon") { + return ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + /*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/Array{"tir.exp"}); + } else if (gpu_targets.count(target_name)) { + return ScheduleRule::AutoInline( + /*into_producer=*/true, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/false, + /*require_injective=*/false, + /*require_ordered=*/false, + /*disallow_op=*/Array{}); + } + LOG(FATAL) << "Unsupported target " << target_name; + return ScheduleRule(nullptr); +} + +std::set GetBlockNames(const IRModule& mod) { + struct BlockNameCollector : public tir::StmtVisitor { + void VisitStmt_(const tir::BlockNode* block) override { + if (block->name_hint == "root") { + StmtVisitor::VisitStmt(block->body); + } else { + block_names.insert(block->name_hint); + } + } + std::set block_names; + }; + + auto prim_func = tir::FindEntryFunc(mod, nullptr); + BlockNameCollector collector; + collector(prim_func->body); + return collector.block_names; +} + +std::vector GetUnscheduledBlocks(const tir::Schedule& sch_orig, + const tir::Schedule& sch) { + auto block_names_orig = GetBlockNames(sch_orig->mod()); + auto block_names = GetBlockNames(sch->mod()); + + std::vector common_blocks; + + std::set_intersection(block_names_orig.begin(), block_names_orig.end(), block_names.begin(), + block_names.end(), std::back_inserter(common_blocks)); + + auto is_scheduled = [=](const std::string& block_name) { + auto loops = sch->GetLoops(sch->GetBlock(block_name)); + auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name)); + if (loops.size() != loops.size()) { + return true; + } + for (size_t i = 0; i < loops.size(); ++i) { + auto loop = sch->Get(loops[i]); + auto loop_orig = sch_orig->Get(loops_orig[i]); + if (loop->kind != loop_orig->kind) { + return true; + } + } + return false; + }; + + std::vector unscheduled_blocks; + + for (auto name : common_blocks) { + if (!is_scheduled(name)) { + unscheduled_blocks.push_back(sch->GetBlock(name)); + } + } + + return unscheduled_blocks; +} + +void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target target) { + auto sch_orig = sch->Copy(); + anchor_trace->ApplyToSchedule(sch, /*remove_postproc=*/false); + + auto unscheduled_blocks = GetUnscheduledBlocks(sch_orig, sch); + + if (unscheduled_blocks.empty()) { + // All blocks have already been scheduled. + // e.g. 
Applying a trace from conv2d -> add to conv2d -> subtract + return; + } + + auto inline_rule = GetDefaultAutoInline(target->kind->name); + Optional last_block; + + for (auto block : unscheduled_blocks) { + auto sch_copy = sch->Copy(); + inline_rule->Apply(sch, block); + if (tvm::StructuralEqual()(sch->mod(), sch_copy->mod())) { + ICHECK(!last_block.defined()); + last_block = block; + } + } + + if (last_block.defined()) { + sch->ReverseComputeInline(last_block.value()); + } +} + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/default_schedule.h b/src/meta_schedule/default_schedule.h new file mode 100644 index 000000000000..c4e662aedcd6 --- /dev/null +++ b/src/meta_schedule/default_schedule.h @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_DEFAULT_SCHEDULE_H_ +#define TVM_META_SCHEDULE_DEFAULT_SCHEDULE_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target target); + +ScheduleRule GetDefaultAutoInline(const std::string& target_name); + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_DEFAULT_SCHEDULE_H_ diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index caa7da170bd6..f00bcf35abd5 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -21,10 +21,14 @@ #include #include #include +#include +#include +#include #include #include "../node/ndarray_hash_equal.h" +#include "../tir/schedule/analysis.h" namespace tvm { namespace meta_schedule { @@ -65,6 +69,34 @@ class SHashHandlerIgnoreNDArray : public SHashHandlerDefault { } }; +const tir::BlockNode* GetAnchorBlock(IRModule mod) { + using namespace tir; + + struct BlockCollector : public StmtVisitor { + void VisitStmt_(const BlockNode* block) override { + blocks.push_back(block); + StmtVisitor::VisitStmt(block->body); + } + std::vector blocks; + }; + + auto prim_func = FindEntryFunc(mod, nullptr); + BlockCollector collector; + collector(prim_func->body); + + ICHECK(collector.blocks.size() > 0); + + for (auto block : collector.blocks) { + if (!block->reads.empty() && !block->writes.empty()) { + return block; + } + } + + LOG(FATAL) << "Cannot find a suitable anchor block"; + + return collector.blocks[0]; +} + class ModuleEqualityIgnoreNDArray : public ModuleEquality { public: size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); } @@ -73,11 +105,26 @@ class ModuleEqualityIgnoreNDArray : public ModuleEquality { } }; +class ModuleEqualityAnchorBlock : public ModuleEquality { + size_t Hash(IRModule mod) const { + auto anchor_block = GetAnchorBlock(mod); + 
return SHashHandlerIgnoreNDArray().Hash(GetRef(anchor_block), false); + } + bool Equal(IRModule lhs, IRModule rhs) const { + auto anchor_block_lhs = GetAnchorBlock(lhs); + auto anchor_block_rhs = GetAnchorBlock(rhs); + return SEqualHandlerIgnoreNDArray().Equal(GetRef(anchor_block_lhs), + GetRef(anchor_block_rhs), false); + } +}; + std::unique_ptr ModuleEquality::Create(const std::string& mod_eq_name) { if (mod_eq_name == "structural") { return std::make_unique(); } else if (mod_eq_name == "ignore-ndarray") { return std::make_unique(); + } else if (mod_eq_name == "anchor-block") { + return std::make_unique(); } LOG(FATAL) << "Unknown module equality " << mod_eq_name; return nullptr; diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 8333833bfafa..d12edb254c9d 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include "../default_schedule.h" namespace tvm { namespace meta_schedule { @@ -53,14 +54,7 @@ ScheduleRule ScheduleRule::PyScheduleRule( Array ScheduleRule::DefaultLLVM() { return { - ScheduleRule::AutoInline( - /*into_producer=*/false, - /*into_consumer=*/true, - /*inline_const_tensor=*/true, - /*disallow_if_then_else=*/true, - /*require_injective=*/true, - /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + GetDefaultAutoInline("llvm"), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -98,14 +92,7 @@ Array ScheduleRule::DefaultCUDA() { Map{{"req", String("must")}, {"levels", Array{3}}, // {"scope", String("local")}}), - ScheduleRule::AutoInline( - /*into_producer=*/true, - /*into_consumer=*/true, - /*inline_const_tensor=*/true, - /*disallow_if_then_else=*/false, - /*require_injective=*/false, - /*require_ordered=*/false, - /*disallow_op=*/Array{}), + GetDefaultAutoInline("cuda"), ScheduleRule::CrossThreadReduction( /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 53107bafb2c0..2aba699c675f 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -45,12 +45,12 @@ String GetRuleKindFromTarget(const Target& target) { } return "cuda"; } - if (target->kind->name == "rocm") { - return "cuda"; - } - if (target->kind->name == "vulkan") { + + const std::unordered_set other_gpu_targets{"rocm", "vulkan", "metal"}; + if (other_gpu_targets.count(target->kind->name)) { return "cuda"; } + LOG(FATAL) << "Unsupported target: " << target; throw; } diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 430b551a3b9e..da14e624015a 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -22,6 +22,8 @@ #include #include +#include + #include "../../meta_schedule/module_equality.h" #include "../../te/operation/create_primfunc.h" #include "./te_compiler_cache.h" @@ -31,6 +33,25 @@ namespace tvm { namespace relay { namespace backend { +class OpCounter : public ExprVisitor { + public: + static size_t GetOpCount(relay::Function func) { + OpCounter counter; + counter(func->body); + return counter.count; + } + + private: + void VisitExpr_(const CallNode* call) final { + if (call->op->IsInstance()) { + ++count; + } + 
ExprVisitor::VisitExpr_(call); + } + + size_t count{0}; +}; + Array ExtractTask(IRModule mod, Target target, Map params, String mod_eq_name) { @@ -52,35 +73,53 @@ Array ExtractTask(IRModule mod, Target target, std::unordered_map cache( /*bucket_count*/ 0, ModuleHash(*mod_eq), ModuleEqual(*mod_eq)); - PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &tir_converter](const Expr& exp) { + std::vector> lower_results; + + PostOrderVisit(mod->Lookup("main"), [&lower_results, &target, &tir_converter](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { return; } - auto [inputs_outputs, constants, fused_name] = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); if (Optional f = tir_converter(inputs_outputs, constants)) { IRModule tir_mod = PrimFuncToIRModule(f.value()); - - auto it = cache.find(tir_mod); - if (it != cache.end()) { - it->second->weight += 1; - return; - } - - // Note that the cache is key-ed on the tir mod, rather than the relay mod - IRModule relay_mod({{GlobalVar(fused_name), relay_func}}); - ExtractedTask task(fused_name, relay_mod, target, {tir_mod}, 1); - tasks.push_back(task); - cache.emplace(tir_mod, task); + lower_results.push_back(std::make_tuple(fused_name, relay_func, tir_mod)); } } }); + + std::vector indices(lower_results.size()); + std::iota(indices.begin(), indices.end(), 0); + + if (mod_eq_name == "anchor-block") { + std::vector op_counts(lower_results.size()); + for (size_t i = 0; i < op_counts.size(); ++i) { + op_counts[i] = OpCounter::GetOpCount(std::get<1>(lower_results[i])); + } + std::sort(indices.begin(), indices.end(), + [&op_counts](int i1, int i2) { return op_counts[i1] <= op_counts[i2]; }); + } + + for (auto i : indices) { + const auto& [fused_name, relay_func, tir_mod] = lower_results[i]; + auto it = cache.find(tir_mod); + if (it != cache.end()) { + it->second->weight += 1; + continue; + } + // Note that the cache is key-ed on the tir mod, rather than the relay mod + IRModule relay_mod({{GlobalVar(fused_name), relay_func}}); + ExtractedTask task(fused_name, relay_mod, target, {tir_mod}, 1); + tasks.push_back(task); + cache.emplace(tir_mod, task); + } + // Tasks are extracted via post order visit, return the reversed list. 
std::reverse(tasks.begin(), tasks.end()); + NameSupply name_supply = NameSupply(""); for (ExtractedTask task : tasks) { task->task_name = name_supply->FreshName(task->task_name); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index e7326ed5dd4d..d8a915b06ee7 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -52,6 +52,8 @@ #include "../../printer/text_printer.h" #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" +#include "../src/meta_schedule/default_schedule.h" +#include "../src/meta_schedule/module_equality.h" #include "../transforms/meta_schedule_layout_rewrite.h" #include "utils.h" @@ -614,9 +616,17 @@ class ScheduleBuilder : public ExprVisitor { MetaScheduleLayoutRewriter::LayoutQueuePush(index_map); } } + Schedule sch = Schedule::Traced(query_mod, /*seed=*/-1, /*debug_mask=*/0, tir::ScheduleErrorRenderLevel::kDetail); - record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); + + if (!meta_schedule::ModuleEquality::Create("ignore-ndarray") + ->Equal(query_mod, opt_record.value()->workload->mod)) { + meta_schedule::ScheduleFusedBlocks(sch, record->trace, target_); + } else { + record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); + } + IRModule mod = sch->mod(); ICHECK_EQ(mod->functions.size(), 1); mod = tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ false)( diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index a95f55357f2d..ef350004ad52 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -354,8 +354,11 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) // `max_function_args` was introduced. It specifies the maximum number of kernel argumetns. 
More // information about this limitation can be found here: // https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc +// See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf TVM_REGISTER_TARGET_KIND("metal", kDLMetal) .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_threads_per_block", Integer(256)) + .add_attr_option("max_shared_memory_per_block", Integer(32768)) .add_attr_option("thread_warp_size", Integer(16)) .add_attr_option("max_function_args", Integer(31)) .set_default_keys({"metal", "gpu"}); diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index adadb46852cc..896d5154876c 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -297,7 +297,17 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) * \param stage The stage to be inserted * \return A SeqStmt, the result after insertion */ -SeqStmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { +Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { + if (const auto* alloc = stmt.as()) { + auto seq_stmt = InsertCacheStage(alloc->body, pos, stage); + return AllocateConst(alloc->buffer_var, + alloc->dtype, + alloc->extents, + alloc->data, + seq_stmt, + alloc->annotations, + alloc->span); + } if (const auto* seq_stmt = stmt.as()) { ObjectPtr result = make_object(*seq_stmt); result->seq.insert(result->seq.begin() + pos, stage); diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index b90b6b85960f..41931a685eef 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -255,6 +255,7 @@ void TraceNode::ApplyToSchedule( const Optional& decision)> decision_provider) const { std::unordered_map rv_map; + static auto kind_get_child_blocks = tir::InstructionKind::Get("GetChildBlocks"); for (const Instruction& inst : this->insts) { if (remove_postproc && inst->kind->IsPostproc()) { break; @@ -266,7 +267,20 @@ void TraceNode::ApplyToSchedule( decision = decision_provider(inst, inputs, attrs, decision); } Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); - TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + if (inst->kind.same_as(kind_get_child_blocks)) { + // We want to allow a trace generated for a single conv2d block to be applied to + // conv2d -> elemwise blocks, where two conv2d are the same workload. + // GetChildBlocks returns a different number of blocks for the two cases above, which violates + // the assumption made by TranslateAddOutputRVs: old_outputs.size() == new_outputs.size(). + // We workaround this problem by assuming that the prefix of the "new" outputs matches with + // the "old" outputs, and truncating the new outputs accordingly. 
+ ICHECK(inst->outputs.size() <= outputs.size()); + TranslateAddOutputRVs( + inst->outputs, Array(outputs.begin(), outputs.begin() + inst->outputs.size()), + &rv_map); + } else { + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + } } } diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index 9a1c9e8dc7f5..f6cef129a638 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -106,6 +106,10 @@ def test_meta_schedule_integration_extract_from_resnet(): for t in extracted_tasks: assert t.task_name in expected_task_names, t.task_name + extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params, + module_equality="anchor-block") + assert len(extracted_tasks) == 15 + @requires_torch def test_meta_schedule_integration_extract_from_bert_base(): From f5ca225e34ed3129c5a250178d6a5db2788ba50d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 Oct 2022 20:23:46 +0900 Subject: [PATCH 02/28] enabling application of anchor trace to different subgraph --- include/tvm/tir/schedule/trace.h | 3 + src/meta_schedule/default_schedule.cc | 61 +---------- src/relay/backend/task_extraction.cc | 1 - src/relay/backend/te_compiler_cache.cc | 9 +- src/tir/schedule/trace.cc | 139 ++++++++++++++++++++++--- 5 files changed, 135 insertions(+), 78 deletions(-) diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index b6b3b57226c8..e2ee2867d032 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -158,6 +158,9 @@ class Trace : public runtime::ObjectRef { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, runtime::ObjectRef, TraceNode); }; +class BlockRV; +std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_trace); + } // namespace tir } // namespace tvm diff --git a/src/meta_schedule/default_schedule.cc b/src/meta_schedule/default_schedule.cc index 2519c3e3f702..c4fbba13c778 100644 --- a/src/meta_schedule/default_schedule.cc +++ b/src/meta_schedule/default_schedule.cc @@ -31,6 +31,7 @@ #include "../printer/text_printer.h" #include "../tir/schedule/analysis.h" +#include "../tir/schedule/utils.h" namespace tvm { namespace meta_schedule { @@ -61,66 +62,8 @@ ScheduleRule GetDefaultAutoInline(const std::string& target_name) { return ScheduleRule(nullptr); } -std::set GetBlockNames(const IRModule& mod) { - struct BlockNameCollector : public tir::StmtVisitor { - void VisitStmt_(const tir::BlockNode* block) override { - if (block->name_hint == "root") { - StmtVisitor::VisitStmt(block->body); - } else { - block_names.insert(block->name_hint); - } - } - std::set block_names; - }; - - auto prim_func = tir::FindEntryFunc(mod, nullptr); - BlockNameCollector collector; - collector(prim_func->body); - return collector.block_names; -} - -std::vector GetUnscheduledBlocks(const tir::Schedule& sch_orig, - const tir::Schedule& sch) { - auto block_names_orig = GetBlockNames(sch_orig->mod()); - auto block_names = GetBlockNames(sch->mod()); - - std::vector common_blocks; - - std::set_intersection(block_names_orig.begin(), block_names_orig.end(), block_names.begin(), - block_names.end(), std::back_inserter(common_blocks)); - - auto is_scheduled = [=](const std::string& block_name) { - auto loops = sch->GetLoops(sch->GetBlock(block_name)); - auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name)); - if (loops.size() != loops.size()) { 
- return true; - } - for (size_t i = 0; i < loops.size(); ++i) { - auto loop = sch->Get(loops[i]); - auto loop_orig = sch_orig->Get(loops_orig[i]); - if (loop->kind != loop_orig->kind) { - return true; - } - } - return false; - }; - - std::vector unscheduled_blocks; - - for (auto name : common_blocks) { - if (!is_scheduled(name)) { - unscheduled_blocks.push_back(sch->GetBlock(name)); - } - } - - return unscheduled_blocks; -} - void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target target) { - auto sch_orig = sch->Copy(); - anchor_trace->ApplyToSchedule(sch, /*remove_postproc=*/false); - - auto unscheduled_blocks = GetUnscheduledBlocks(sch_orig, sch); + auto unscheduled_blocks = tir::ApplyAnchorTrace(sch, anchor_trace); if (unscheduled_blocks.empty()) { // All blocks have already been scheduled. diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index da14e624015a..ba3bd958b714 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -119,7 +119,6 @@ Array ExtractTask(IRModule mod, Target target, // Tasks are extracted via post order visit, return the reversed list. std::reverse(tasks.begin(), tasks.end()); - NameSupply name_supply = NameSupply(""); for (ExtractedTask task : tasks) { task->task_name = name_supply->FreshName(task->task_name); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d8a915b06ee7..7b9aca994a2a 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -44,6 +44,7 @@ #include #include +#include #include #include #include @@ -463,7 +464,9 @@ class AllocateConstReplaceConstant : public StmtExprMutator { // Construct a schedule for a given Relay primitive function and target. class ScheduleBuilder : public ExprVisitor { public: - explicit ScheduleBuilder(Target target) : target_(target) { + explicit ScheduleBuilder(Target target) + : target_(target), + mod_eq_structural_(meta_schedule::ModuleEquality::Create("ignore-ndarray")) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); if (backend::IsMetaScheduleEnabled()) { @@ -620,8 +623,7 @@ class ScheduleBuilder : public ExprVisitor { Schedule sch = Schedule::Traced(query_mod, /*seed=*/-1, /*debug_mask=*/0, tir::ScheduleErrorRenderLevel::kDetail); - if (!meta_schedule::ModuleEquality::Create("ignore-ndarray") - ->Equal(query_mod, opt_record.value()->workload->mod)) { + if (!mod_eq_structural_->Equal(query_mod, opt_record.value()->workload->mod)) { meta_schedule::ScheduleFusedBlocks(sch, record->trace, target_); } else { record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); @@ -708,6 +710,7 @@ class ScheduleBuilder : public ExprVisitor { int anchor_op_pattern_{0}; bool use_auto_scheduler_; Optional database_; + std::unique_ptr mod_eq_structural_; }; /*! 
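For reference, a minimal usage sketch of how the "anchor-block" module equality introduced by this series is selected from the Python side, following the resnet18 integration test updated in PATCH 01/28. Here `mod` and `params` are assumed to come from an existing Relay model (for example the testing resnet18 workload used in that test); the accepted values for `module_equality` are "structural", "ignore-ndarray", and "anchor-block", matching the dispatch in ModuleEquality::Create above.

    # Hedged sketch: `mod` and `params` are assumed to be a Relay module and its
    # parameter dict, e.g. obtained from tvm.relay.testing as in the integration test.
    from tvm import meta_schedule as ms

    # "anchor-block" hashes/compares only the anchor (e.g. conv2d) block of each
    # extracted TIR module, so subgraphs that differ only in the fused elementwise
    # ops are merged into a single task with an increased weight.
    extracted_tasks = ms.relay_integration.extract_tasks(
        mod, target="llvm", params=params, module_equality="anchor-block"
    )
    for task in extracted_tasks:
        print(task.task_name, task.weight)

At compile time, the recorded anchor trace is then re-applied to each fused subgraph through ScheduleFusedBlocks, as wired up in the te_compiler_cache.cc change directly above.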
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 41931a685eef..24f279d67bc5 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -255,7 +255,6 @@ void TraceNode::ApplyToSchedule( const Optional& decision)> decision_provider) const { std::unordered_map rv_map; - static auto kind_get_child_blocks = tir::InstructionKind::Get("GetChildBlocks"); for (const Instruction& inst : this->insts) { if (remove_postproc && inst->kind->IsPostproc()) { break; @@ -267,20 +266,7 @@ void TraceNode::ApplyToSchedule( decision = decision_provider(inst, inputs, attrs, decision); } Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); - if (inst->kind.same_as(kind_get_child_blocks)) { - // We want to allow a trace generated for a single conv2d block to be applied to - // conv2d -> elemwise blocks, where two conv2d are the same workload. - // GetChildBlocks returns a different number of blocks for the two cases above, which violates - // the assumption made by TranslateAddOutputRVs: old_outputs.size() == new_outputs.size(). - // We workaround this problem by assuming that the prefix of the "new" outputs matches with - // the "old" outputs, and truncating the new outputs accordingly. - ICHECK(inst->outputs.size() <= outputs.size()); - TranslateAddOutputRVs( - inst->outputs, Array(outputs.begin(), outputs.begin() + inst->outputs.size()), - &rv_map); - } else { - TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); - } + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); } } @@ -535,6 +521,129 @@ struct EnterPostprocTraits : public UnpackedInstTraits { TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); +std::set GetBlockNames(const IRModule& mod) { + struct BlockNameCollector : public tir::StmtVisitor { + void VisitStmt_(const tir::BlockNode* block) override { + block_names.insert(block->name_hint); + StmtVisitor::VisitStmt(block->body); + } + std::set block_names; + }; + + auto prim_func = tir::FindEntryFunc(mod, nullptr); + BlockNameCollector collector; + collector(prim_func->body); + return collector.block_names; +} + +std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_trace) { + std::unordered_map rv_map; + static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks"); + static auto kind_get_block = InstructionKind::Get("GetBlock"); + const auto block_names_orig = GetBlockNames(sch->mod()); + std::unordered_set foreign_blocks; + std::unordered_set foreign_loops; + std::set scheduled_blocks; + + const auto sch_orig = sch->Copy(); + + auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) { + for (auto input : inst->inputs) { + if (!input.defined()) continue; + if ((input->IsInstance() && foreign_blocks.count(Downcast(input))) || + (input->IsInstance() && foreign_loops.count(Downcast(input)))) { + return false; + } + } + return true; + }; + + for (const auto& inst : anchor_trace->insts) { + if (!is_inst_applicable(inst)) { + for (auto output : inst->outputs) { + if (output->IsInstance()) { + foreign_blocks.insert(Downcast(output)); + } else if (output->IsInstance()) { + foreign_loops.insert(Downcast(output)); + } + } + continue; + } + + if (inst->kind.same_as(kind_get_block)) { + auto find_prefix_any = [&block_names_orig](const std::string& block_name) { + for (auto name : block_names_orig) { + if (block_name.find(name) == 0) { + return true; + } + } + return false; + }; + + auto block_name = Downcast(inst->attrs[0]); + ICHECK(block_name.defined()); + + if (!find_prefix_any(block_name)) { 
+ auto block = Downcast(inst->outputs[0]); + foreign_blocks.insert(block); + continue; + } else { + scheduled_blocks.insert(block_name); + } + } + + Array inputs = TranslateInputRVs(inst->inputs, rv_map); + Optional decision = anchor_trace->GetDecision(inst); + Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); + + if (inst->kind.same_as(kind_get_child_blocks)) { + // We want to allow a trace generated for a single conv2d block to be applied to + // conv2d -> elemwise blocks, where two conv2d are the same workload. + // GetChildBlocks returns a different number of blocks for the two cases above, which + // violates the assumption made by TranslateAddOutputRVs: old_outputs.size() == + // new_outputs.size(). We workaround this problem by assuming that the prefix of the "new" + // outputs matches with the "old" outputs, and truncating the new outputs accordingly. + ICHECK(inst->outputs.size() <= outputs.size()); + TranslateAddOutputRVs( + inst->outputs, Array(outputs.begin(), outputs.begin() + inst->outputs.size()), + &rv_map); + } else { + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + } + } + + const auto block_names_now = GetBlockNames(sch->mod()); + + auto is_scheduled = [=, &scheduled_blocks](const std::string& block_name) { + if (!block_names_now.count(block_name) || scheduled_blocks.count(block_name)) { + return true; + } + auto loops = sch->GetLoops(sch->GetBlock(block_name)); + auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name)); + if (loops.size() != loops_orig.size()) { + return true; + } + for (size_t i = 0; i < loops.size(); ++i) { + auto loop = sch->Get(loops[i]); + auto loop_orig = sch_orig->Get(loops_orig[i]); + if (loop->kind != loop_orig->kind) { + return true; + } + } + return false; + }; + + std::vector unscheduled_blocks; + + for (auto name : block_names_orig) { + if (!is_scheduled(name)) { + unscheduled_blocks.push_back(sch->GetBlock(name)); + } + } + + return unscheduled_blocks; +} + /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(TraceNode); From e6a4f21eb92bc7deba059bb47f6570d54e08b412 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Oct 2022 07:12:00 +0900 Subject: [PATCH 03/28] fixed anchor block extraction --- src/meta_schedule/module_equality.cc | 21 ++++++++++++------- .../test_meta_schedule_relay_integration.py | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index f00bcf35abd5..3997ad53f3da 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -26,6 +26,7 @@ #include #include +#include #include "../node/ndarray_hash_equal.h" #include "../tir/schedule/analysis.h" @@ -69,7 +70,7 @@ class SHashHandlerIgnoreNDArray : public SHashHandlerDefault { } }; -const tir::BlockNode* GetAnchorBlock(IRModule mod) { +std::optional GetAnchorBlock(IRModule mod) { using namespace tir; struct BlockCollector : public StmtVisitor { @@ -87,14 +88,12 @@ const tir::BlockNode* GetAnchorBlock(IRModule mod) { ICHECK(collector.blocks.size() > 0); for (auto block : collector.blocks) { - if (!block->reads.empty() && !block->writes.empty()) { + if (block->init && !block->reads.empty() && !block->writes.empty()) { return block; } } - LOG(FATAL) << "Cannot find a suitable anchor block"; - - return collector.blocks[0]; + return std::nullopt; } class ModuleEqualityIgnoreNDArray : public ModuleEquality { @@ -108,13 +107,19 @@ class ModuleEqualityIgnoreNDArray : public 
ModuleEquality { class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { auto anchor_block = GetAnchorBlock(mod); - return SHashHandlerIgnoreNDArray().Hash(GetRef(anchor_block), false); + if (anchor_block) { + return SHashHandlerIgnoreNDArray().Hash(GetRef(*anchor_block), false); + } + return ModuleEqualityIgnoreNDArray().Hash(mod); } bool Equal(IRModule lhs, IRModule rhs) const { auto anchor_block_lhs = GetAnchorBlock(lhs); auto anchor_block_rhs = GetAnchorBlock(rhs); - return SEqualHandlerIgnoreNDArray().Equal(GetRef(anchor_block_lhs), - GetRef(anchor_block_rhs), false); + if (anchor_block_lhs && anchor_block_rhs) { + return SEqualHandlerIgnoreNDArray().Equal(GetRef(*anchor_block_lhs), + GetRef(*anchor_block_rhs), false); + } + return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs); } }; diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index f6cef129a638..ddec3dc6f75c 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -108,7 +108,7 @@ def test_meta_schedule_integration_extract_from_resnet(): extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params, module_equality="anchor-block") - assert len(extracted_tasks) == 15 + assert len(extracted_tasks) == 16 @requires_torch From f107fd758ca291af3261648d8fc75bdb550c2ba4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 Oct 2022 05:12:23 +0900 Subject: [PATCH 04/28] fixed UB in task extraction --- src/relay/backend/task_extraction.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index ba3bd958b714..1dd1ed3ca701 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -99,8 +99,9 @@ Array ExtractTask(IRModule mod, Target target, for (size_t i = 0; i < op_counts.size(); ++i) { op_counts[i] = OpCounter::GetOpCount(std::get<1>(lower_results[i])); } + std::sort(indices.begin(), indices.end(), - [&op_counts](int i1, int i2) { return op_counts[i1] <= op_counts[i2]; }); + [&op_counts](int i1, int i2) { return op_counts[i1] < op_counts[i2]; }); } for (auto i : indices) { From b2dfc18fac6f3e05284c773346989b8c642c9f6d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 Oct 2022 05:32:32 +0900 Subject: [PATCH 05/28] Reworked anchor trace application and inlining logic --- include/tvm/meta_schedule/schedule_rule.h | 3 +- include/tvm/tir/schedule/trace.h | 3 - src/meta_schedule/default_schedule.cc | 185 ++++++++++++++++-- src/meta_schedule/schedule_rule/auto_bind.cc | 5 +- .../schedule_rule/schedule_rule.cc | 2 +- .../schedule/primitive/cache_read_write.cc | 9 +- src/tir/schedule/trace.cc | 123 ------------ src/tir/schedule/utils.h | 6 + 8 files changed, 185 insertions(+), 151 deletions(-) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 3bc30e09c74a..2dc3140ac814 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -247,7 +247,8 @@ class ScheduleRule : public runtime::ObjectRef { * \param thread_extents Candidates of thread axis extent. 
* \return The schedule rule created */ - TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array thread_extents); + TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array thread_extents, + int max_threads_per_block = -1); /*! * \brief Create a schedule rule with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index e2ee2867d032..b6b3b57226c8 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -158,9 +158,6 @@ class Trace : public runtime::ObjectRef { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, runtime::ObjectRef, TraceNode); }; -class BlockRV; -std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_trace); - } // namespace tir } // namespace tvm diff --git a/src/meta_schedule/default_schedule.cc b/src/meta_schedule/default_schedule.cc index c4fbba13c778..9f253115b3b2 100644 --- a/src/meta_schedule/default_schedule.cc +++ b/src/meta_schedule/default_schedule.cc @@ -62,8 +62,161 @@ ScheduleRule GetDefaultAutoInline(const std::string& target_name) { return ScheduleRule(nullptr); } +std::set GetBlockNames(const IRModule& mod) { + struct BlockNameCollector : public tir::StmtVisitor { + void VisitStmt_(const tir::BlockNode* block) override { + block_names.insert(block->name_hint); + StmtVisitor::VisitStmt(block->body); + } + std::set block_names; + }; + + auto prim_func = tir::FindEntryFunc(mod, nullptr); + BlockNameCollector collector; + collector(prim_func->body); + return collector.block_names; +} + +std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_trace, + Target target) { + using namespace tir; + auto block_names_orig = GetBlockNames(sch->mod()); + const auto sch_orig = sch->Copy(); + static auto kind_get_block = InstructionKind::Get("GetBlock"); + + std::set get_block_names; + for (const auto& inst : anchor_trace->insts) { + if (inst->kind.same_as(kind_get_block)) { + auto block_name = Downcast(inst->attrs[0]); + ICHECK(block_name.defined()); + get_block_names.insert(block_name); + } + } + + auto inline_rule = GetDefaultAutoInline(target->kind->name); + + for (auto name : block_names_orig) { + auto block = sch->GetBlock(name); + if (IsSpatial(sch->GetSRef(block)) && !get_block_names.count(name)) { + // LOG(INFO) << "Inlining " << name; + inline_rule->Apply(sch, block); + } + } + + // LOG(INFO) << "After inlining ¥n" << tir::AsTVMScript(sch->mod()); + + std::unordered_map rv_map; + static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks"); + + std::unordered_set foreign_blocks; + std::unordered_set foreign_loops; + std::set scheduled_blocks; + + auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) { + for (auto input : inst->inputs) { + if (!input.defined()) continue; + if ((input->IsInstance() && foreign_blocks.count(Downcast(input))) || + (input->IsInstance() && foreign_loops.count(Downcast(input)))) { + return false; + } + } + return true; + }; + + for (const auto& inst : anchor_trace->insts) { + if (!is_inst_applicable(inst)) { + for (auto output : inst->outputs) { + if (output->IsInstance()) { + foreign_blocks.insert(Downcast(output)); + } else if (output->IsInstance()) { + foreign_loops.insert(Downcast(output)); + } + } + continue; + } + + if (inst->kind.same_as(kind_get_block)) { + auto find_prefix_any = [&block_names_orig](const std::string& block_name) { + for (auto name : 
block_names_orig) { + if (block_name.find(name) == 0) { + return true; + } + } + return false; + }; + + auto block_name = Downcast(inst->attrs[0]); + ICHECK(block_name.defined()); + + if (!find_prefix_any(block_name)) { + auto block = Downcast(inst->outputs[0]); + foreign_blocks.insert(block); + continue; + } else { + scheduled_blocks.insert(block_name); + } + } + + Array inputs = TranslateInputRVs(inst->inputs, rv_map); + Optional decision = anchor_trace->GetDecision(inst); + Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); + + if (inst->kind.same_as(kind_get_child_blocks)) { + // We want to allow a trace generated for a single conv2d block to be applied to + // conv2d -> elemwise blocks, where two conv2d are the same workload. + // GetChildBlocks returns a different number of blocks for the two cases above, which + // violates the assumption made by TranslateAddOutputRVs: old_outputs.size() == + // new_outputs.size(). We workaround this problem by assuming that the prefix of the "new" + // outputs matches with the "old" outputs, and truncating the new outputs accordingly. + ICHECK(inst->outputs.size() <= outputs.size()); + TranslateAddOutputRVs( + inst->outputs, Array(outputs.begin(), outputs.begin() + inst->outputs.size()), + &rv_map); + } else { + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + } + } + + const auto block_names_now = GetBlockNames(sch->mod()); + + auto is_scheduled = [=, &scheduled_blocks](const std::string& block_name) { + if (!block_names_now.count(block_name) || scheduled_blocks.count(block_name)) { + return true; + } + auto loops = sch->GetLoops(sch->GetBlock(block_name)); + auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name)); + if (loops.size() != loops_orig.size()) { + return true; + } + for (size_t i = 0; i < loops.size(); ++i) { + auto loop = sch->Get(loops[i]); + auto loop_orig = sch_orig->Get(loops_orig[i]); + if (loop->kind != loop_orig->kind) { + return true; + } + } + return false; + }; + + std::vector unscheduled_blocks; + + for (auto name : block_names_orig) { + if (!is_scheduled(name)) { + unscheduled_blocks.push_back(sch->GetBlock(name)); + } + } + + return unscheduled_blocks; +} + void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target target) { - auto unscheduled_blocks = tir::ApplyAnchorTrace(sch, anchor_trace); + // LOG(INFO) << anchor_trace; + + auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace, target); + + // LOG(INFO) << tir::AsTVMScript(sch->mod()); + // LOG(INFO) << unscheduled_blocks.size(); + ICHECK(unscheduled_blocks.size() <= 1); if (unscheduled_blocks.empty()) { // All blocks have already been scheduled. 
@@ -71,21 +224,25 @@ void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target return; } - auto inline_rule = GetDefaultAutoInline(target->kind->name); - Optional last_block; - - for (auto block : unscheduled_blocks) { - auto sch_copy = sch->Copy(); - inline_rule->Apply(sch, block); - if (tvm::StructuralEqual()(sch->mod(), sch_copy->mod())) { - ICHECK(!last_block.defined()); - last_block = block; - } - } + auto last_block_producers = sch->GetProducers(unscheduled_blocks[0]); + if (last_block_producers.size() == 1 && tir::IsSpatial(sch->GetSRef(last_block_producers[0]))) { + // Inline into the cache write stage + sch->ReverseComputeInline(unscheduled_blocks[0]); + } else if (target->kind->name == "llvm" || target->kind->name == "hexagon") { + sch->Parallel(sch->Fuse(sch->GetLoops(unscheduled_blocks[0]))); + } else if (gpu_targets.count(target->kind->name)) { + Optional max_threads_per_block = target->GetAttr("max_threads_per_block"); + ICHECK(max_threads_per_block.defined()) + << "ValueError: missing attribute `max_threads_per_block` in the target"; - if (last_block.defined()) { - sch->ReverseComputeInline(last_block.value()); + auto auto_bind_rule = + ScheduleRule::AutoBind(/*max_threadblocks=*/256, + /*thread_extents*/ Array{32, 64, 128, 256, 512, 1024}, + max_threads_per_block.value()->value); + auto_bind_rule->Apply(sch, unscheduled_blocks[0]); } + + // LOG(INFO) << tir::AsTVMScript(sch->mod()); } } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 7af1418d8f3e..4d16a6d4d65d 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -208,10 +208,11 @@ Array AutoBindNode::Apply(const tir::Schedule& sch, const tir::Bl return {sch}; } -ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_extents) { +ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_extents, + int max_threads_per_block) { ObjectPtr n = make_object(); n->max_threadblocks_ = max_threadblocks; - n->max_threads_per_block_ = -1; + n->max_threads_per_block_ = max_threads_per_block; n->thread_extents_ = std::move(thread_extents); return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index d12edb254c9d..207d331384d6 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ -#include "../utils.h" #include "../default_schedule.h" +#include "../utils.h" namespace tvm { namespace meta_schedule { diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 896d5154876c..2c86c2df2d25 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -300,13 +300,8 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { if (const auto* alloc = stmt.as()) { auto seq_stmt = InsertCacheStage(alloc->body, pos, stage); - return AllocateConst(alloc->buffer_var, - alloc->dtype, - alloc->extents, - alloc->data, - seq_stmt, - alloc->annotations, - alloc->span); + return AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, seq_stmt, + alloc->annotations, alloc->span); } if (const auto* seq_stmt = stmt.as()) { ObjectPtr result = make_object(*seq_stmt); diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 24f279d67bc5..b90b6b85960f 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -521,129 +521,6 @@ struct EnterPostprocTraits : public UnpackedInstTraits { TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); -std::set GetBlockNames(const IRModule& mod) { - struct BlockNameCollector : public tir::StmtVisitor { - void VisitStmt_(const tir::BlockNode* block) override { - block_names.insert(block->name_hint); - StmtVisitor::VisitStmt(block->body); - } - std::set block_names; - }; - - auto prim_func = tir::FindEntryFunc(mod, nullptr); - BlockNameCollector collector; - collector(prim_func->body); - return collector.block_names; -} - -std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_trace) { - std::unordered_map rv_map; - static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks"); - static auto kind_get_block = InstructionKind::Get("GetBlock"); - const auto block_names_orig = GetBlockNames(sch->mod()); - std::unordered_set foreign_blocks; - std::unordered_set foreign_loops; - std::set scheduled_blocks; - - const auto sch_orig = sch->Copy(); - - auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) { - for (auto input : inst->inputs) { - if (!input.defined()) continue; - if ((input->IsInstance() && foreign_blocks.count(Downcast(input))) || - (input->IsInstance() && foreign_loops.count(Downcast(input)))) { - return false; - } - } - return true; - }; - - for (const auto& inst : anchor_trace->insts) { - if (!is_inst_applicable(inst)) { - for (auto output : inst->outputs) { - if (output->IsInstance()) { - foreign_blocks.insert(Downcast(output)); - } else if (output->IsInstance()) { - foreign_loops.insert(Downcast(output)); - } - } - continue; - } - - if (inst->kind.same_as(kind_get_block)) { - auto find_prefix_any = [&block_names_orig](const std::string& block_name) { - for (auto name : block_names_orig) { - if (block_name.find(name) == 0) { - return true; - } - } - return false; - }; - - auto block_name = Downcast(inst->attrs[0]); - ICHECK(block_name.defined()); - - if (!find_prefix_any(block_name)) { - auto block = Downcast(inst->outputs[0]); - foreign_blocks.insert(block); - continue; - } else { - scheduled_blocks.insert(block_name); - } - } - - Array inputs = TranslateInputRVs(inst->inputs, rv_map); - Optional decision = anchor_trace->GetDecision(inst); - Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); - - if 
(inst->kind.same_as(kind_get_child_blocks)) { - // We want to allow a trace generated for a single conv2d block to be applied to - // conv2d -> elemwise blocks, where two conv2d are the same workload. - // GetChildBlocks returns a different number of blocks for the two cases above, which - // violates the assumption made by TranslateAddOutputRVs: old_outputs.size() == - // new_outputs.size(). We workaround this problem by assuming that the prefix of the "new" - // outputs matches with the "old" outputs, and truncating the new outputs accordingly. - ICHECK(inst->outputs.size() <= outputs.size()); - TranslateAddOutputRVs( - inst->outputs, Array(outputs.begin(), outputs.begin() + inst->outputs.size()), - &rv_map); - } else { - TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); - } - } - - const auto block_names_now = GetBlockNames(sch->mod()); - - auto is_scheduled = [=, &scheduled_blocks](const std::string& block_name) { - if (!block_names_now.count(block_name) || scheduled_blocks.count(block_name)) { - return true; - } - auto loops = sch->GetLoops(sch->GetBlock(block_name)); - auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name)); - if (loops.size() != loops_orig.size()) { - return true; - } - for (size_t i = 0; i < loops.size(); ++i) { - auto loop = sch->Get(loops[i]); - auto loop_orig = sch_orig->Get(loops_orig[i]); - if (loop->kind != loop_orig->kind) { - return true; - } - } - return false; - }; - - std::vector unscheduled_blocks; - - for (auto name : block_names_orig) { - if (!is_scheduled(name)) { - unscheduled_blocks.push_back(sch->GetBlock(name)); - } - } - - return unscheduled_blocks; -} - /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(TraceNode); diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index c289309acc2d..d2312ea41ec2 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -442,6 +442,12 @@ inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { } } +Array TranslateInputRVs(const Array& inputs, + const std::unordered_map& rv_map); + +void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, + std::unordered_map* rv_map); + } // namespace tir } // namespace tvm From a4bacd1428bc61d14a638cd59e8516ce2b417230 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 Oct 2022 16:11:47 +0900 Subject: [PATCH 06/28] fixed anchor block extraction for winograd --- src/meta_schedule/module_equality.cc | 65 ++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index 3997ad53f3da..5e03b2ab60e1 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -30,6 +30,7 @@ #include "../node/ndarray_hash_equal.h" #include "../tir/schedule/analysis.h" +#include "tvm/tir/analysis.h" namespace tvm { namespace meta_schedule { @@ -70,6 +71,47 @@ class SHashHandlerIgnoreNDArray : public SHashHandlerDefault { } }; +tir::Stmt GetEncolsingLoop(const tir::BlockNode* block, tir::Stmt func_body) { + using namespace tir; + struct GetRootSeqStmt : public StmtVisitor { + void VisitStmt_(const SeqStmtNode* seq) override { + result = seq; + } + const SeqStmtNode* result; + }; + + struct BlockFinder : public StmtVisitor { + BlockFinder(const BlockNode* tgt) : target(tgt) {} + + void VisitStmt_(const BlockNode* block) override { + if (block == target) { + found = true; + } + } + + const BlockNode* target; + bool found = false; + }; + + GetRootSeqStmt 
seq_finder; + seq_finder(func_body); + + const SeqStmtNode* seq = seq_finder.result; + + for (auto stmt: seq->seq) { + if (stmt->IsInstance()) { + BlockFinder finder(block);; + finder(stmt); + if (finder.found) { + return stmt; + } + } + } + + LOG(FATAL) << "Enclosing loop not found for a block " << GetRef(block); + return Stmt(); +} + std::optional GetAnchorBlock(IRModule mod) { using namespace tir; @@ -87,13 +129,30 @@ std::optional GetAnchorBlock(IRModule mod) { ICHECK(collector.blocks.size() > 0); + std::vector candidates; for (auto block : collector.blocks) { - if (block->init && !block->reads.empty() && !block->writes.empty()) { - return block; + if (block->init) { + candidates.push_back(block); } } - return std::nullopt; + if (candidates.empty()) { + return std::nullopt; + } else if (candidates.size() == 1) { + return candidates[0]; + } + + double best_flops = -1; + int best_idx = 0; + for (size_t i = 0; i < candidates.size(); ++i) { + auto loop = GetEncolsingLoop(candidates[i], prim_func->body); + auto flops = tir::EstimateTIRFlops(loop); + if (flops > best_flops) { + best_flops = flops; + best_idx = i; + } + } + return candidates[best_idx]; } class ModuleEqualityIgnoreNDArray : public ModuleEquality { From 718a514b464f0c22247bb74909cc449ffb168bfb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 Oct 2022 16:53:56 +0900 Subject: [PATCH 07/28] fix inline logic for winograd --- src/meta_schedule/default_schedule.cc | 20 +++++++++++++++++++- src/meta_schedule/module_equality.h | 4 ++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/meta_schedule/default_schedule.cc b/src/meta_schedule/default_schedule.cc index 9f253115b3b2..c8a520cd4663 100644 --- a/src/meta_schedule/default_schedule.cc +++ b/src/meta_schedule/default_schedule.cc @@ -32,6 +32,7 @@ #include "../printer/text_printer.h" #include "../tir/schedule/analysis.h" #include "../tir/schedule/utils.h" +#include "module_equality.h" namespace tvm { namespace meta_schedule { @@ -77,6 +78,16 @@ std::set GetBlockNames(const IRModule& mod) { return collector.block_names; } +bool IsAncestor(tir::BlockRV b1, tir::BlockRV b2, tir::Schedule sch) { + if (sch->Get(b1)->name_hint == sch->Get(b2)->name_hint) { + return true; + } + for (auto prod : sch->GetProducers(b2)) { + if (IsAncestor(b1, prod, sch)) return true; + } + return false; +} + std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_trace, Target target) { using namespace tir; @@ -94,11 +105,17 @@ std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_ } auto inline_rule = GetDefaultAutoInline(target->kind->name); + auto anchor_block = GetAnchorBlock(sch->mod()); + std::optional anchor_block_rv = std::nullopt; + if (anchor_block) { + anchor_block_rv = sch->GetBlock((*anchor_block)->name_hint); + } for (auto name : block_names_orig) { auto block = sch->GetBlock(name); + if (anchor_block_rv && IsAncestor(block, *anchor_block_rv, sch)) continue; if (IsSpatial(sch->GetSRef(block)) && !get_block_names.count(name)) { - // LOG(INFO) << "Inlining " << name; + LOG(INFO) << "Inlining " << name; inline_rule->Apply(sch, block); } } @@ -210,6 +227,7 @@ std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_ } void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target target) { + // LOG(INFO) << tir::AsTVMScript(sch->mod()); // LOG(INFO) << anchor_trace; auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace, target); diff --git a/src/meta_schedule/module_equality.h b/src/meta_schedule/module_equality.h 
index 8c99b563551b..dac7417efab6 100644 --- a/src/meta_schedule/module_equality.h +++ b/src/meta_schedule/module_equality.h @@ -20,9 +20,11 @@ #define TVM_META_SCHEDULE_MODULE_EQUALITY_H_ #include +#include #include #include +#include namespace tvm { namespace meta_schedule { @@ -69,6 +71,8 @@ class ModuleEqual { const ModuleEquality& mod_eq_; }; +std::optional GetAnchorBlock(IRModule mod); + } // namespace meta_schedule } // namespace tvm From 6bbd4d8504922c4a25f10ae77565c1513e6b9a10 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 Oct 2022 18:00:22 +0900 Subject: [PATCH 08/28] refactor, clean up, renaming --- include/tvm/meta_schedule/schedule_rule.h | 2 + src/meta_schedule/module_equality.cc | 103 +-------------- src/meta_schedule/module_equality.h | 4 - .../schedule_rule/schedule_rule.cc | 1 - .../space_generator/space_generator.cc | 3 +- .../{default_schedule.cc => trace_apply.cc} | 122 ++++++------------ .../{default_schedule.h => trace_apply.h} | 23 +++- src/meta_schedule/utils.h | 35 +++++ src/relay/backend/te_compiler_cache.cc | 7 +- src/tir/schedule/utils.h | 14 ++ 10 files changed, 122 insertions(+), 192 deletions(-) rename src/meta_schedule/{default_schedule.cc => trace_apply.cc} (66%) rename src/meta_schedule/{default_schedule.h => trace_apply.h} (53%) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 2dc3140ac814..1b018512146f 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -245,6 +245,8 @@ class ScheduleRule : public runtime::ObjectRef { * \brief Auto bind loops around the block to BlockIdx and ThreadIdx * \param max_threadblocks The maximum number of threadblock on GPU * \param thread_extents Candidates of thread axis extent. + * \param max_threads_per_block The maximum number of threads per block, if it is known + * when this schedule rule is created. 
* \return The schedule rule created */ TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array thread_extents, diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index 5e03b2ab60e1..9340b373d5b4 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -21,16 +21,11 @@ #include #include #include -#include -#include -#include +#include #include -#include #include "../node/ndarray_hash_equal.h" -#include "../tir/schedule/analysis.h" -#include "tvm/tir/analysis.h" namespace tvm { namespace meta_schedule { @@ -71,90 +66,6 @@ class SHashHandlerIgnoreNDArray : public SHashHandlerDefault { } }; -tir::Stmt GetEncolsingLoop(const tir::BlockNode* block, tir::Stmt func_body) { - using namespace tir; - struct GetRootSeqStmt : public StmtVisitor { - void VisitStmt_(const SeqStmtNode* seq) override { - result = seq; - } - const SeqStmtNode* result; - }; - - struct BlockFinder : public StmtVisitor { - BlockFinder(const BlockNode* tgt) : target(tgt) {} - - void VisitStmt_(const BlockNode* block) override { - if (block == target) { - found = true; - } - } - - const BlockNode* target; - bool found = false; - }; - - GetRootSeqStmt seq_finder; - seq_finder(func_body); - - const SeqStmtNode* seq = seq_finder.result; - - for (auto stmt: seq->seq) { - if (stmt->IsInstance()) { - BlockFinder finder(block);; - finder(stmt); - if (finder.found) { - return stmt; - } - } - } - - LOG(FATAL) << "Enclosing loop not found for a block " << GetRef(block); - return Stmt(); -} - -std::optional GetAnchorBlock(IRModule mod) { - using namespace tir; - - struct BlockCollector : public StmtVisitor { - void VisitStmt_(const BlockNode* block) override { - blocks.push_back(block); - StmtVisitor::VisitStmt(block->body); - } - std::vector blocks; - }; - - auto prim_func = FindEntryFunc(mod, nullptr); - BlockCollector collector; - collector(prim_func->body); - - ICHECK(collector.blocks.size() > 0); - - std::vector candidates; - for (auto block : collector.blocks) { - if (block->init) { - candidates.push_back(block); - } - } - - if (candidates.empty()) { - return std::nullopt; - } else if (candidates.size() == 1) { - return candidates[0]; - } - - double best_flops = -1; - int best_idx = 0; - for (size_t i = 0; i < candidates.size(); ++i) { - auto loop = GetEncolsingLoop(candidates[i], prim_func->body); - auto flops = tir::EstimateTIRFlops(loop); - if (flops > best_flops) { - best_flops = flops; - best_idx = i; - } - } - return candidates[best_idx]; -} - class ModuleEqualityIgnoreNDArray : public ModuleEquality { public: size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); } @@ -165,18 +76,18 @@ class ModuleEqualityIgnoreNDArray : public ModuleEquality { class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { - auto anchor_block = GetAnchorBlock(mod); + auto anchor_block = tir::FindAnchorBlock(mod); if (anchor_block) { - return SHashHandlerIgnoreNDArray().Hash(GetRef(*anchor_block), false); + return SHashHandlerIgnoreNDArray().Hash(GetRef(anchor_block), false); } return ModuleEqualityIgnoreNDArray().Hash(mod); } bool Equal(IRModule lhs, IRModule rhs) const { - auto anchor_block_lhs = GetAnchorBlock(lhs); - auto anchor_block_rhs = GetAnchorBlock(rhs); + auto anchor_block_lhs = tir::FindAnchorBlock(lhs); + auto anchor_block_rhs = tir::FindAnchorBlock(rhs); if (anchor_block_lhs && anchor_block_rhs) { - return SEqualHandlerIgnoreNDArray().Equal(GetRef(*anchor_block_lhs), - 
GetRef(*anchor_block_rhs), false); + return SEqualHandlerIgnoreNDArray().Equal(GetRef(anchor_block_lhs), + GetRef(anchor_block_rhs), false); } return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs); } diff --git a/src/meta_schedule/module_equality.h b/src/meta_schedule/module_equality.h index dac7417efab6..8c99b563551b 100644 --- a/src/meta_schedule/module_equality.h +++ b/src/meta_schedule/module_equality.h @@ -20,11 +20,9 @@ #define TVM_META_SCHEDULE_MODULE_EQUALITY_H_ #include -#include #include #include -#include namespace tvm { namespace meta_schedule { @@ -71,8 +69,6 @@ class ModuleEqual { const ModuleEquality& mod_eq_; }; -std::optional GetAnchorBlock(IRModule mod); - } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 207d331384d6..bd492d03eac6 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include "../default_schedule.h" #include "../utils.h" namespace tvm { diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 2aba699c675f..bcc0673e5924 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -46,8 +46,7 @@ String GetRuleKindFromTarget(const Target& target) { return "cuda"; } - const std::unordered_set other_gpu_targets{"rocm", "vulkan", "metal"}; - if (other_gpu_targets.count(target->kind->name)) { + if (IsGPUTarget(target->kind->name)) { return "cuda"; } diff --git a/src/meta_schedule/default_schedule.cc b/src/meta_schedule/trace_apply.cc similarity index 66% rename from src/meta_schedule/default_schedule.cc rename to src/meta_schedule/trace_apply.cc index c8a520cd4663..9258cf74985f 100644 --- a/src/meta_schedule/default_schedule.cc +++ b/src/meta_schedule/trace_apply.cc @@ -16,60 +16,28 @@ * specific language governing permissions and limitations * under the License. 
*/ -#include "default_schedule.h" +#include "trace_apply.h" -#include -#include -#include +#include #include -#include -#include #include +#include #include #include -#include "../printer/text_printer.h" -#include "../tir/schedule/analysis.h" -#include "../tir/schedule/utils.h" -#include "module_equality.h" +#include "utils.h" namespace tvm { namespace meta_schedule { -static const std::unordered_set gpu_targets{"cuda", "rocm", "vulkan", "metal"}; - -ScheduleRule GetDefaultAutoInline(const std::string& target_name) { - if (target_name == "llvm" || target_name == "hexagon") { - return ScheduleRule::AutoInline( - /*into_producer=*/false, - /*into_consumer=*/true, - /*inline_const_tensor=*/true, - /*disallow_if_then_else=*/true, - /*require_injective=*/true, - /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}); - } else if (gpu_targets.count(target_name)) { - return ScheduleRule::AutoInline( - /*into_producer=*/true, - /*into_consumer=*/true, - /*inline_const_tensor=*/true, - /*disallow_if_then_else=*/false, - /*require_injective=*/false, - /*require_ordered=*/false, - /*disallow_op=*/Array{}); - } - LOG(FATAL) << "Unsupported target " << target_name; - return ScheduleRule(nullptr); -} - -std::set GetBlockNames(const IRModule& mod) { +std::unordered_set GetBlockNames(const IRModule& mod) { struct BlockNameCollector : public tir::StmtVisitor { void VisitStmt_(const tir::BlockNode* block) override { block_names.insert(block->name_hint); StmtVisitor::VisitStmt(block->body); } - std::set block_names; + std::unordered_set block_names; }; auto prim_func = tir::FindEntryFunc(mod, nullptr); @@ -78,7 +46,9 @@ std::set GetBlockNames(const IRModule& mod) { return collector.block_names; } -bool IsAncestor(tir::BlockRV b1, tir::BlockRV b2, tir::Schedule sch) { +using namespace tir; + +bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) { if (sch->Get(b1)->name_hint == sch->Get(b2)->name_hint) { return true; } @@ -88,14 +58,9 @@ bool IsAncestor(tir::BlockRV b1, tir::BlockRV b2, tir::Schedule sch) { return false; } -std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_trace, - Target target) { - using namespace tir; - auto block_names_orig = GetBlockNames(sch->mod()); - const auto sch_orig = sch->Copy(); +void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { static auto kind_get_block = InstructionKind::Get("GetBlock"); - - std::set get_block_names; + std::unordered_set get_block_names; for (const auto& inst : anchor_trace->insts) { if (inst->kind.same_as(kind_get_block)) { auto block_name = Downcast(inst->attrs[0]); @@ -104,30 +69,33 @@ std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_ } } - auto inline_rule = GetDefaultAutoInline(target->kind->name); - auto anchor_block = GetAnchorBlock(sch->mod()); + auto anchor_block = FindAnchorBlock(sch->mod()); std::optional anchor_block_rv = std::nullopt; if (anchor_block) { - anchor_block_rv = sch->GetBlock((*anchor_block)->name_hint); + anchor_block_rv = sch->GetBlock(anchor_block->name_hint); } - for (auto name : block_names_orig) { + auto inline_rule = GetDefaultAutoInline(target->kind->name); + + for (auto name : GetBlockNames(sch->mod())) { auto block = sch->GetBlock(name); if (anchor_block_rv && IsAncestor(block, *anchor_block_rv, sch)) continue; if (IsSpatial(sch->GetSRef(block)) && !get_block_names.count(name)) { - LOG(INFO) << "Inlining " << name; inline_rule->Apply(sch, block); } } +} - // LOG(INFO) << "After inlining ¥n" << tir::AsTVMScript(sch->mod()); - - std::unordered_map rv_map; 
+std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks"); + static auto kind_get_block = InstructionKind::Get("GetBlock"); + + const auto block_names_orig = GetBlockNames(sch->mod()); + const auto sch_orig = sch->Copy(); + std::unordered_map rv_map; std::unordered_set foreign_blocks; std::unordered_set foreign_loops; - std::set scheduled_blocks; auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) { for (auto input : inst->inputs) { @@ -169,8 +137,6 @@ std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_ auto block = Downcast(inst->outputs[0]); foreign_blocks.insert(block); continue; - } else { - scheduled_blocks.insert(block_name); } } @@ -194,12 +160,7 @@ std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_ } } - const auto block_names_now = GetBlockNames(sch->mod()); - - auto is_scheduled = [=, &scheduled_blocks](const std::string& block_name) { - if (!block_names_now.count(block_name) || scheduled_blocks.count(block_name)) { - return true; - } + auto is_scheduled = [=](const std::string& block_name) { auto loops = sch->GetLoops(sch->GetBlock(block_name)); auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name)); if (loops.size() != loops_orig.size()) { @@ -215,10 +176,11 @@ std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_ return false; }; + const auto block_names_now = GetBlockNames(sch->mod()); std::vector unscheduled_blocks; for (auto name : block_names_orig) { - if (!is_scheduled(name)) { + if (block_names_now.count(name) && name != "root" && !is_scheduled(name)) { unscheduled_blocks.push_back(sch->GetBlock(name)); } } @@ -226,30 +188,30 @@ std::vector ApplyAnchorTrace(tir::Schedule sch, tir::Trace anchor_ return unscheduled_blocks; } -void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target target) { - // LOG(INFO) << tir::AsTVMScript(sch->mod()); - // LOG(INFO) << anchor_trace; +void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm::Target& target) { + InlinePostBlocks(sch, anchor_trace, target); - auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace, target); - - // LOG(INFO) << tir::AsTVMScript(sch->mod()); - // LOG(INFO) << unscheduled_blocks.size(); + auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace); ICHECK(unscheduled_blocks.size() <= 1); if (unscheduled_blocks.empty()) { // All blocks have already been scheduled. - // e.g. Applying a trace from conv2d -> add to conv2d -> subtract + // e.g. 
Applying a trace from conv2d -> add to
+  //      - conv2d -> add -> add
+  //      - conv2d -> subtract
     return;
   }
 
-  auto last_block_producers = sch->GetProducers(unscheduled_blocks[0]);
-  if (last_block_producers.size() == 1 && tir::IsSpatial(sch->GetSRef(last_block_producers[0]))) {
+  auto last_block = unscheduled_blocks[0];
+  auto last_block_producers = sch->GetProducers(last_block);
+
+  if (last_block_producers.size() == 1 && IsSpatial(sch->GetSRef(last_block_producers[0]))) {
     // Inline into the cache write stage
-    sch->ReverseComputeInline(unscheduled_blocks[0]);
+    sch->ReverseComputeInline(last_block);
   } else if (target->kind->name == "llvm" || target->kind->name == "hexagon") {
-    sch->Parallel(sch->Fuse(sch->GetLoops(unscheduled_blocks[0])));
-  } else if (gpu_targets.count(target->kind->name)) {
-    Optional max_threads_per_block = target->GetAttr("max_threads_per_block");
+    sch->Parallel(sch->Fuse(sch->GetLoops(last_block)));
+  } else if (IsGPUTarget(target->kind->name)) {
+    auto max_threads_per_block = target->GetAttr("max_threads_per_block");
     ICHECK(max_threads_per_block.defined())
         << "ValueError: missing attribute `max_threads_per_block` in the target";
@@ -257,10 +219,8 @@ void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target
         ScheduleRule::AutoBind(/*max_threadblocks=*/256,
                                /*thread_extents*/ Array{32, 64, 128, 256, 512, 1024},
                                max_threads_per_block.value()->value);
-    auto_bind_rule->Apply(sch, unscheduled_blocks[0]);
+    auto_bind_rule->Apply(sch, last_block);
   }
-
-  // LOG(INFO) << tir::AsTVMScript(sch->mod());
 }
 
 }  // namespace meta_schedule
diff --git a/src/meta_schedule/default_schedule.h b/src/meta_schedule/trace_apply.h
similarity index 53%
rename from src/meta_schedule/default_schedule.h
rename to src/meta_schedule/trace_apply.h
index c4e662aedcd6..9a9068ab914f 100644
--- a/src/meta_schedule/default_schedule.h
+++ b/src/meta_schedule/trace_apply.h
@@ -16,22 +16,33 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#ifndef TVM_META_SCHEDULE_DEFAULT_SCHEDULE_H_
-#define TVM_META_SCHEDULE_DEFAULT_SCHEDULE_H_
+#ifndef TVM_META_SCHEDULE_TRACE_APPLY_H_
+#define TVM_META_SCHEDULE_TRACE_APPLY_H_
 
 #include
 #include
 #include
 #include
+#include
+
 namespace tvm {
 namespace meta_schedule {
 
-void ScheduleFusedBlocks(tir::Schedule sch, tir::Trace anchor_trace, tvm::Target target);
-
-ScheduleRule GetDefaultAutoInline(const std::string& target_name);
+/*!
+ * \brief Apply the trace from a TIR module whose anchor block is the same but whose fused
+ * elementwise op blocks differ. This function can be used for transferring a trace tuned on a
+ * conv2d -> add subgraph to other subgraphs having the same conv2d workload, for example. We call
+ * such a trace an "anchor trace". Those blocks that are not scheduled by the given anchor trace
+ * will be either inlined or parallelized.
+ * \param sch The schedule to which the anchor trace is applied.
+ * \param anchor_trace The trace tuned on another subgraph with the same anchor-block workload.
+ * \param target The target information needed for inlining and parallelization.
+ */ +void ScheduleUsingAnchorTrace(tir::Schedule sch, const tir::Trace& anchor_trace, + const tvm::Target& target); } // namespace meta_schedule } // namespace tvm -#endif // TVM_META_SCHEDULE_DEFAULT_SCHEDULE_H_ +#endif // TVM_META_SCHEDULE_TRACE_APPLY_H_ diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 824cfcd6aa5c..949588168534 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -502,6 +502,41 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { } } +/*! \brief Returns true if the given target is one of the supported gpu targets. */ +inline bool IsGPUTarget(const std::string& target_name) { + static const std::unordered_set gpu_targets{"cuda", "rocm", "vulkan", "metal"}; + return gpu_targets.count(target_name); +} + +/*! + * \brief Create an AutoInline schedule rule for the given target. + * \param target_name The name of the target ("llvm", "cuda", etc.) + * \return The AutoInline schedule rule for the given target. + */ +inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { + if (target_name == "llvm" || target_name == "hexagon") { + return ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + /*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/Array{"tir.exp"}); + } else if (IsGPUTarget(target_name)) { + return ScheduleRule::AutoInline( + /*into_producer=*/true, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/false, + /*require_injective=*/false, + /*require_ordered=*/false, + /*disallow_op=*/Array{}); + } + LOG(FATAL) << "Unsupported target " << target_name; + return ScheduleRule(nullptr); +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 7b9aca994a2a..c97efb565d9d 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -53,8 +53,8 @@ #include "../../printer/text_printer.h" #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" -#include "../src/meta_schedule/default_schedule.h" #include "../src/meta_schedule/module_equality.h" +#include "../src/meta_schedule/trace_apply.h" #include "../transforms/meta_schedule_layout_rewrite.h" #include "utils.h" @@ -624,7 +624,10 @@ class ScheduleBuilder : public ExprVisitor { tir::ScheduleErrorRenderLevel::kDetail); if (!mod_eq_structural_->Equal(query_mod, opt_record.value()->workload->mod)) { - meta_schedule::ScheduleFusedBlocks(sch, record->trace, target_); + // When the database lookup succeeds while structural equality check fails, + // it implies that the anchor block based equality has been used during tuning. + // The trace in the record cannot directly be applied to this query module. + meta_schedule::ScheduleUsingAnchorTrace(sch, record->trace, target_); } else { record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); } diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index d2312ea41ec2..2065a64a6214 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -442,9 +442,23 @@ inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { } } +/******** Utilites for trace application ********/ + +/*! + * \brief Translate the input objects using the provided substitution map. + * \param inputs The input objects. + * \param rv_map The substitution map for variables. 
+ * \return The transformed objects. + */ Array TranslateInputRVs(const Array& inputs, const std::unordered_map& rv_map); +/*! + * \brief Update the variable substitution map according to the new outputs. + * \param old_outputs The previous outputs of a schedule instruction. + * \param new_outputs The new outputs of the same schedule instruction. + * \param rv_map The substitution map for variables. + */ void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, std::unordered_map* rv_map); From fbe5160f9561e9ebec044bb571128870ee169f79 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 Oct 2022 16:44:51 +0900 Subject: [PATCH 09/28] fix reverse compute inline unapplicable case --- src/meta_schedule/trace_apply.cc | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 9258cf74985f..2ad90b22465c 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -89,6 +89,8 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks"); static auto kind_get_block = InstructionKind::Get("GetBlock"); + static auto kind_compute_inline = InstructionKind::Get("ComputeInline"); + static auto kind_reverse_compute_inline = InstructionKind::Get("ReverseComputeInline"); const auto block_names_orig = GetBlockNames(sch->mod()); const auto sch_orig = sch->Copy(); @@ -120,6 +122,8 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { continue; } + Array inputs = TranslateInputRVs(inst->inputs, rv_map); + if (inst->kind.same_as(kind_get_block)) { auto find_prefix_any = [&block_names_orig](const std::string& block_name) { for (auto name : block_names_orig) { @@ -138,9 +142,24 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { foreign_blocks.insert(block); continue; } + } else if (inst->kind.same_as(kind_reverse_compute_inline)) { + auto block = Downcast(inputs[0]); + auto block_sref = sch->GetSRef(block); + if (!CanReverseComputeInline(sch->state(), block_sref)) { + ICHECK(CanComputeInline(sch->state(), block_sref)); + sch->ComputeInline(block); + continue; + } + } else if (inst->kind.same_as(kind_compute_inline)) { + auto block = Downcast(inputs[0]); + auto block_sref = sch->GetSRef(block); + if (!CanComputeInline(sch->state(), block_sref)) { + ICHECK(CanReverseComputeInline(sch->state(), block_sref)); + sch->ReverseComputeInline(block); + continue; + } } - Array inputs = TranslateInputRVs(inst->inputs, rv_map); Optional decision = anchor_trace->GetDecision(inst); Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); From cf4d8b7702686d05257de5c3814f0e460c82e7f0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 Oct 2022 17:58:24 +0900 Subject: [PATCH 10/28] fixed get_block applicablity condition --- src/meta_schedule/trace_apply.cc | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 2ad90b22465c..02c6352a93a8 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -125,19 +125,10 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { Array inputs = TranslateInputRVs(inst->inputs, rv_map); if (inst->kind.same_as(kind_get_block)) { - auto find_prefix_any = [&block_names_orig](const 
std::string& block_name) { - for (auto name : block_names_orig) { - if (block_name.find(name) == 0) { - return true; - } - } - return false; - }; - auto block_name = Downcast(inst->attrs[0]); - ICHECK(block_name.defined()); + auto block_names_current = GetBlockNames(sch->mod()); - if (!find_prefix_any(block_name)) { + if (!block_names_current.count(block_name)) { auto block = Downcast(inst->outputs[0]); foreign_blocks.insert(block); continue; @@ -208,6 +199,8 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { } void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm::Target& target) { + // LOG(INFO) << AsTVMScript(sch->mod()); + // LOG(INFO) << anchor_trace; InlinePostBlocks(sch, anchor_trace, target); auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace); @@ -218,6 +211,7 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm // e.g. Applying a trace from conv2d -> add to // - conv2d -> add -> add // - conv2d -> subtract + // LOG(INFO) << "All scheduled " << AsTVMScript(sch->mod()); return; } @@ -240,6 +234,8 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm max_threads_per_block.value()->value); auto_bind_rule->Apply(sch, last_block); } + + // LOG(INFO) << AsTVMScript(sch->mod()); } } // namespace meta_schedule From 8ee4da617e27be0ce76b12b74fdcd4aa7c8c5c48 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 25 Oct 2022 08:56:41 +0900 Subject: [PATCH 11/28] adding test --- src/meta_schedule/trace_apply.cc | 8 -------- .../test_meta_schedule_trace_apply.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 8 deletions(-) create mode 100644 tests/python/unittest/test_meta_schedule_trace_apply.py diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 02c6352a93a8..332d0c1b6b36 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -199,8 +199,6 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { } void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm::Target& target) { - // LOG(INFO) << AsTVMScript(sch->mod()); - // LOG(INFO) << anchor_trace; InlinePostBlocks(sch, anchor_trace, target); auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace); @@ -208,10 +206,6 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm if (unscheduled_blocks.empty()) { // All blocks have already been scheduled. - // e.g. Applying a trace from conv2d -> add to - // - conv2d -> add -> add - // - conv2d -> subtract - // LOG(INFO) << "All scheduled " << AsTVMScript(sch->mod()); return; } @@ -234,8 +228,6 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm max_threads_per_block.value()->value); auto_bind_rule->Apply(sch, last_block); } - - // LOG(INFO) << AsTVMScript(sch->mod()); } } // namespace meta_schedule diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py new file mode 100644 index 000000000000..98ceac4486c7 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm From 51b376675a54378443bf059dab9d538b1a79293b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 25 Oct 2022 17:17:01 +0900 Subject: [PATCH 12/28] introduce HasBlock utility --- python/tvm/tir/schedule/analysis.py | 4 +++ src/meta_schedule/trace_apply.cc | 28 +++---------------- src/tir/schedule/analysis/analysis.cc | 2 ++ src/tir/schedule/utils.h | 20 +++++++++++++ .../test_meta_schedule_vnni_integration.py | 3 +- 5 files changed, 32 insertions(+), 25 deletions(-) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 90c585ac8ce1..b132e4059821 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -122,3 +122,7 @@ def get_auto_tensorize_mapping_info( intrinsics. """ return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) # type: ignore + + +def has_block(sch, block_name): + return _ffi_api.HasBlock(sch, block_name); diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 332d0c1b6b36..da63e205c89a 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -31,21 +31,6 @@ namespace tvm { namespace meta_schedule { -std::unordered_set GetBlockNames(const IRModule& mod) { - struct BlockNameCollector : public tir::StmtVisitor { - void VisitStmt_(const tir::BlockNode* block) override { - block_names.insert(block->name_hint); - StmtVisitor::VisitStmt(block->body); - } - std::unordered_set block_names; - }; - - auto prim_func = tir::FindEntryFunc(mod, nullptr); - BlockNameCollector collector; - collector(prim_func->body); - return collector.block_names; -} - using namespace tir; bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) { @@ -124,15 +109,10 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { Array inputs = TranslateInputRVs(inst->inputs, rv_map); - if (inst->kind.same_as(kind_get_block)) { - auto block_name = Downcast(inst->attrs[0]); - auto block_names_current = GetBlockNames(sch->mod()); - - if (!block_names_current.count(block_name)) { - auto block = Downcast(inst->outputs[0]); - foreign_blocks.insert(block); - continue; - } + if (inst->kind.same_as(kind_get_block) && !HasBlock(sch, Downcast(inst->attrs[0]))) { + auto block = Downcast(inst->outputs[0]); + foreign_blocks.insert(block); + continue; } else if (inst->kind.same_as(kind_reverse_compute_inline)) { auto block = Downcast(inputs[0]); auto block_sref = sch->GetSRef(block); diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d8b4f31f4c1b..64cc8013d716 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2059,5 +2059,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); }); +TVM_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock); + } // namespace tir } // 
namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 2065a64a6214..e7363bd20a34 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -442,6 +442,26 @@ inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { } } +inline std::unordered_set GetBlockNames(const IRModule& mod) { + struct BlockNameCollector : public tir::StmtVisitor { + void VisitStmt_(const tir::BlockNode* block) override { + block_names.insert(block->name_hint); + StmtVisitor::VisitStmt(block->body); + } + std::unordered_set block_names; + }; + + auto prim_func = tir::FindEntryFunc(mod, nullptr); + BlockNameCollector collector; + collector(prim_func->body); + return collector.block_names; +} + +inline bool HasBlock(const Schedule& sch, const std::string& block_name) { + auto block_names = GetBlockNames(sch->mod()); + return block_names.count(block_name); +} + /******** Utilites for trace application ********/ /*! diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_vnni_integration.py index d0bfc913eca6..527dc46ad2db 100644 --- a/tests/python/unittest/test_meta_schedule_vnni_integration.py +++ b/tests/python/unittest/test_meta_schedule_vnni_integration.py @@ -26,6 +26,7 @@ from tvm import relay from tvm._ffi import register_func from tvm.tir.schedule import BlockRV, Schedule +from tvm.tir.schedule.analysis import has_block from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN logging.basicConfig( @@ -43,7 +44,7 @@ def _schedule_dense(m: Optional[int], do_tune: bool): def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool: if sch.mod.attrs is not None and "dense" not in sch.mod.attrs["task_name"]: return False - if dense_block is None: + if dense_block is None and has_block(sch, "compute"): dense_block = sch.get_block("compute") assert "dense_vnni" in sch.get(dense_block).annotations["schedule_rule"] From 289cd9a2830177ee4c2f5c09b9569cf51d678d2b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Oct 2022 04:05:04 +0900 Subject: [PATCH 13/28] Decoupled trace creation and application in Trace::ApplyJSONToschedule --- include/tvm/tir/schedule/trace.h | 8 +++ python/tvm/tir/schedule/trace.py | 11 ++++ src/tir/schedule/trace.cc | 106 ++++++++++++++++++++----------- 3 files changed, 88 insertions(+), 37 deletions(-) diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index b6b3b57226c8..7ed573d12139 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -148,6 +148,14 @@ class Trace : public runtime::ObjectRef { * \param decisions The decisions made in sampling */ explicit Trace(Array insts, Map decisions); + + /*! + * \brief Restore a trace from its serialized representation + * \param json The JSON-serialized trace + * \return The restored trace + */ + static Trace FromJSON(ObjectRef json); + /*! 
* \brief Apply a JSON-serialized trace to a TensorIR schedule * \param json The JSON-serialized trace diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index da599081df3b..228db9fea8a0 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -231,6 +231,17 @@ def with_decision( remove_postproc, ) + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Trace": + """Apply a JSON-serialized trace to a TensorIR schedule + + Parameters + ---------- + json_obj : JSON_TYPE + The JSON-serialized tracenn + """ + return _ffi_api.TraceFromJSON(json_obj) # type: ignore # pylint: disable=no-member + def simplified(self, remove_postproc: bool) -> "Trace": """Simplify the trace with dead-code elimination diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index b90b6b85960f..02fa513f4180 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -135,35 +135,35 @@ Array TranslateInputRVs(const Array& inputs, Array results; results.reserve(inputs.size()); for (const ObjectRef& input : inputs) { - // Case 3. integer or floating-point number if (input->IsInstance() || input->IsInstance()) { + // Case 3. integer or floating-point number results.push_back(input); - continue; - } - // Case 4. array - if (input->IsInstance()) { + } else if (input->IsInstance()) { + // Case 4. array results.push_back(TranslateInputRVs(Downcast>(input), named_rvs)); - continue; - } - // Case 5. dict - if (input->IsInstance()) { + } else if (input->IsInstance()) { + // Case 5. dict + results.push_back(input); + } else if (input->IsInstance()) { + const auto* str = input.as(); + CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names"; + const char* name = str->data; + int64_t size = str->size; + if (auto it = named_rvs.find(name); it != named_rvs.end()) { + // Case 0 & 1. None, BlockRV, LoopRV, VarRV + results.push_back(it->second); + } else { + // Case 2. string + if (size >= 2 && name[0] == '"' && name[size - 1] == '"') { + results.push_back(String(std::string(name + 1, size - 2))); + } else { + // strings without quotation + results.push_back(input); + } + } + } else { results.push_back(input); - continue; - } - const auto* str = input.as(); - CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey(); - CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names"; - const char* name = str->data; - int64_t size = str->size; - // Case 2. string - if (size >= 2 && name[0] == '"' && name[size - 1] == '"') { - results.push_back(String(std::string(name + 1, size - 2))); - continue; } - // Case 0 & 1. 
None, BlockRV, LoopRV, VarRV - auto it = named_rvs.find(name); - CHECK(it != named_rvs.end()) << "ValueError: The random variable is not defined: " << name; - results.push_back(it->second); } return results; } @@ -255,18 +255,40 @@ void TraceNode::ApplyToSchedule( const Optional& decision)> decision_provider) const { std::unordered_map rv_map; + std::unordered_map named_rvs{{"None", ObjectRef{nullptr}}}; + + auto all_string = [](const Array& objs) { + for (auto obj : objs) { + if (!obj->IsInstance()) { + return false; + } + } + return true; + }; + for (const Instruction& inst : this->insts) { if (remove_postproc && inst->kind->IsPostproc()) { break; } - Array inputs = TranslateInputRVs(inst->inputs, rv_map); + Array inputs = TranslateInputRVs(TranslateInputRVs(inst->inputs, named_rvs), rv_map); + Array attrs = inst->attrs; Optional decision = this->GetDecision(inst); if (decision_provider != nullptr) { decision = decision_provider(inst, inputs, attrs, decision); } Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); - TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + + if (all_string(inst->outputs)) { + Array outputs_str; + for (auto str : inst->outputs) { + outputs_str.push_back(Downcast(str)); + } + TranslateAddOutputRVs(outputs_str, outputs, &named_rvs); + TranslateAddOutputRVs(outputs, outputs, &rv_map); + } else { + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + } } } @@ -330,7 +352,7 @@ Array TraceNode::AsPython(bool remove_postproc) const { return py_trace; } -void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { +Trace Trace::FromJSON(ObjectRef json) { Array json_insts{nullptr}; Array json_decisions{nullptr}; // Parse `json` into `json_insts` and `json_decisions` @@ -369,13 +391,15 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { decisions[index] = std::move(decision); } // Parse `json_insts` - std::unordered_map named_rvs{{"None", ObjectRef{nullptr}}}; int i = 0; + Array instructions; + Map decisions_map; + for (const ObjectRef& inst_entry : json_insts) { InstructionKind kind{nullptr}; Array inputs{nullptr}; Array attrs{nullptr}; - Array outputs{ObjectPtr{nullptr}}; + Array outputs_str{ObjectPtr{nullptr}}; // Parse the entry try { const auto* arr = inst_entry.as(); @@ -391,25 +415,33 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { kind = InstructionKind::Get(arr0->data); inputs = GetRef>(arr1); attrs = GetRef>(arr2); - outputs = GetRef>(arr3); + outputs_str = GetRef>(arr3); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, " "inputs, attrs, outputs], but gets: " << inst_entry; throw; } - // Parse inputs - inputs = TranslateInputRVs(inputs, named_rvs); // Parse attrs if (kind->f_attrs_from_json != nullptr) { attrs = kind->f_attrs_from_json(attrs); } - // Apply to the schedule - Array new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); - // Parse outputs - TranslateAddOutputRVs(outputs, new_outputs, &named_rvs); + + Array outputs; + for (auto str : outputs_str) { + outputs.push_back(str); + } + instructions.push_back(Instruction(kind, inputs, attrs, outputs)); + if (decisions[i]) { + decisions_map.Set(instructions.back(), decisions[i].value()); + } ++i; } + return Trace(instructions, decisions_map); +} + +void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { + FromJSON(json)->ApplyToSchedule(sch, false); } /**************** Creation ****************/ @@ -550,6 +582,6 @@ 
TVM_REGISTER_GLOBAL("tir.schedule.TraceWithDecision") TVM_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") .set_body_typed(Trace::ApplyJSONToSchedule); - +TVM_REGISTER_GLOBAL("tir.schedule.TraceFromJSON").set_body_typed(Trace::FromJSON); } // namespace tir } // namespace tvm From 5c0d47fd69ddc34d372f92264157b56404abaebe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Oct 2022 05:44:16 +0900 Subject: [PATCH 14/28] add test --- python/tvm/meta_schedule/__init__.py | 1 + src/meta_schedule/trace_apply.cc | 3 + .../test_meta_schedule_trace_apply.py | 158 ++++++++++++++++++ .../test_meta_schedule_vnni_integration.py | 3 +- 4 files changed, 164 insertions(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 04acdc9d4a75..0dd679e047e0 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -30,6 +30,7 @@ search_strategy, space_generator, tir_integration, + trace_apply, ) from .builder import Builder from .cost_model import CostModel diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index da63e205c89a..a1a8ec61c6da 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -210,5 +210,8 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm } } +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace") + .set_body_typed(ScheduleUsingAnchorTrace); + } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index 98ceac4486c7..4e02416171b2 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -17,3 +17,161 @@ import pytest import tvm +import tvm.meta_schedule as ms +from tvm.script import tir as T +from tvm.tir import Schedule, floormod, floordiv +from tvm.target import Target + + +@tvm.script.ir_module +class Dense: + @T.prim_func + def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_matmul_NT: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("T_matmul_NT"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(p0[i, k], p1[j, k]) + T.writes(T_matmul_NT[i, j]) + T.block_attr({"layout_free_placeholders":[]}) + with T.init(): + T_matmul_NT[i, j] = T.float32(0) + T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k] + + +@tvm.script.ir_module +class DenseAdd: + @T.prim_func + def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + T_matmul_NT = T.alloc_buffer([128, 128], dtype="float32") + compile_engine_const = T.alloc_buffer([], dtype="float32") + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("T_matmul_NT"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(p0[i, k], p1[j, k]) + T.writes(T_matmul_NT[i, j]) + T.block_attr({"layout_free_placeholders":[]}) + with T.init(): + T_matmul_NT[i, j] = T.float32(0) + T_matmul_NT[i, j] = 
T_matmul_NT[i, j] + p0[i, k] * p1[j, k] + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = T.float32(1) + for i0, i1 in T.grid(128, 128): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_matmul_NT[ax0, ax1], compile_engine_const[()]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = T_matmul_NT[ax0, ax1] + compile_engine_const[()] + + +@tvm.script.ir_module +class DenseAdd_scheduled: + @T.prim_func + def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + T_matmul_NT_global = T.alloc_buffer([128, 128], dtype="float32") + p1_global = T.alloc_buffer([2, 128, 64], dtype="float32") + for ax0, ax1 in T.grid(128, 128): + with T.block("p1_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(p1[v0, v1]) + T.writes(p1_global[v0 // 64, v1, v0 % 64]) + T.block_attr({"meta_schedule.layout_rewrite_preproc":1}) + p1_global[v0 // 64, v1, v0 % 64] = p1[v0, v1] + for i0_0_i1_0_fused_fused in T.parallel(4): + for i0_1, i1_1 in T.grid(8, 1): + for i0_2_init, i1_2_init, i0_3_init in T.grid(4, 1, 2): + for i1_3_fused_init in T.vectorized(64): + with T.block("T_matmul_NT_init"): + i = T.axis.spatial(128, i0_0_i1_0_fused_fused // 2 * 64 + i0_1 * 8 + i0_2_init * 2 + i0_3_init) + j = T.axis.spatial(128, i0_0_i1_0_fused_fused % 2 * 64 + i1_1 * 64 + i1_2_init * 64 + i1_3_fused_init) + T.reads() + T.writes(T_matmul_NT_global[i, j]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T_matmul_NT_global[i, j] = T.float32(0) + for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(128, 4, 1, 1, 2): + for i1_3_fused in T.vectorized(64): + with T.block("T_matmul_NT_update"): + i = T.axis.spatial(128, i0_0_i1_0_fused_fused // 2 * 64 + i0_1 * 8 + i0_2 * 2 + i0_3) + j = T.axis.spatial(128, i0_0_i1_0_fused_fused % 2 * 64 + i1_1 * 64 + i1_2 * 64 + i1_3_fused) + k = T.axis.reduce(128, i2_0 + i2_1) + T.reads(T_matmul_NT_global[i, j], p0[i, k], p1_global[j // 64, k, j % 64]) + T.writes(T_matmul_NT_global[i, j]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T_matmul_NT_global[i, j] = T_matmul_NT_global[i, j] + p0[i, k] * p1_global[j // 64, k, j % 64] + for ax0 in T.serial(64): + for ax1_fused in T.vectorized(64): + with T.block("T_matmul_NT_global"): + v0 = T.axis.spatial(128, i0_0_i1_0_fused_fused // 2 * 64 + ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_fused_fused % 2 * 64 + ax1_fused) + T.reads(T_matmul_NT_global[v0, v1]) + T.writes(T_add[v0, v1]) + T_add[v0, v1] = T_matmul_NT_global[v0, v1] + T.float32(1) + + +def test_dense_add_cpu(): + def apply_anchor_trace(sch: Schedule) -> None: + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[2, 8, 4, 2]) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True) + v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[2, 1, 1, 64]) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16], 
preserve_unit_iters=True) + v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[128, 1]) + l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + sch.reverse_compute_at(block=b25, loop=l17, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=160) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64) + v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26) + sch.enter_postproc() + b27 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit") + b28, b29 = sch.get_child_blocks(b27) + l30, l31, l32, l33, l34, l35, l36, l37, l38, l39 = sch.get_loops(block=b28) + l40 = sch.fuse(l30, l31, preserve_unit_iters=True) + sch.parallel(loop=l40) + l41 = sch.fuse(l39, preserve_unit_iters=True) + sch.vectorize(loop=l41) + l42, l43, l44 = sch.get_loops(block=b29) + l45 = sch.fuse(l42, preserve_unit_iters=True) + sch.parallel(loop=l45) + l46 = sch.fuse(l44, preserve_unit_iters=True) + sch.vectorize(loop=l46) + b47 = sch.get_block(name="T_matmul_NT", func_name="main") + l48, l49, l50, l51, l52, l53, l54, l55, l56 = sch.get_loops(block=b47) + b57 = sch.decompose_reduction(block=b47, loop=l51) + b58 = sch.get_block(name="T_matmul_NT_update", func_name="main") + b59 = sch.cache_read(block=b58, read_buffer_index=2, storage_scope="global") + sch.transform_layout(block=b58, buffer=("read", 2), index_map=tvm.tir.IndexMap.from_func(lambda i0, i1: (floordiv(i0, 64), i1, floormod(i0, 64),), inverse_index_map=lambda i0, i1, i2: (((i0*64) + i2), i1,)), pad_value=None) + sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) + + anchor_sch = Schedule(Dense) + apply_anchor_trace(anchor_sch) + trace = anchor_sch.trace + + sch = Schedule(DenseAdd) + target = Target("llvm") + + ms.trace_apply.schedule_using_anchor_trace(sch, trace, target) + + tvm.ir.assert_structural_equal(DenseAdd_scheduled, sch.mod) diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_vnni_integration.py index 527dc46ad2db..1f91dc593143 100644 --- a/tests/python/unittest/test_meta_schedule_vnni_integration.py +++ b/tests/python/unittest/test_meta_schedule_vnni_integration.py @@ -44,7 +44,8 @@ def _schedule_dense(m: Optional[int], do_tune: bool): def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool: if sch.mod.attrs is not None and "dense" not in sch.mod.attrs["task_name"]: return False - if dense_block is None and has_block(sch, "compute"): + if dense_block is None: + assert has_block(sch, "compute") dense_block = sch.get_block("compute") assert "dense_vnni" in sch.get(dense_block).annotations["schedule_rule"] From fbb2361856c2eb522fedfc90a5c457f55666ad10 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Oct 2022 06:18:23 +0900 Subject: [PATCH 15/28] adding more test --- .../test_meta_schedule_trace_apply.py | 156 +++++++++++++++++- 1 file changed, 148 insertions(+), 8 deletions(-) diff --git 
a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index 4e02416171b2..7f2357a77088 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -17,12 +17,26 @@ import pytest import tvm +import tvm.testing import tvm.meta_schedule as ms from tvm.script import tir as T from tvm.tir import Schedule, floormod, floordiv from tvm.target import Target +def verify(anchor_mod, anchor_trace_fun, target_mod, target, ref): + anchor_sch = Schedule(anchor_mod) + anchor_trace_fun(anchor_sch) + trace = anchor_sch.trace + + sch = Schedule(target_mod) + + ms.trace_apply.schedule_using_anchor_trace(sch, trace, Target(target)) + + # print(sch.mod.script()) + tvm.ir.assert_structural_equal(ref, sch.mod) + + @tvm.script.ir_module class Dense: @T.prim_func @@ -75,7 +89,7 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" @tvm.script.ir_module -class DenseAdd_scheduled: +class DenseAdd_scheduled_cpu: @T.prim_func def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: # function attr dict @@ -165,13 +179,139 @@ def apply_anchor_trace(sch: Schedule) -> None: sch.transform_layout(block=b58, buffer=("read", 2), index_map=tvm.tir.IndexMap.from_func(lambda i0, i1: (floordiv(i0, 64), i1, floormod(i0, 64),), inverse_index_map=lambda i0, i1, i2: (((i0*64) + i2), i1,)), pad_value=None) sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) - anchor_sch = Schedule(Dense) - apply_anchor_trace(anchor_sch) - trace = anchor_sch.trace + verify(Dense, apply_anchor_trace, DenseAdd, "llvm", DenseAdd_scheduled_cpu) + + +@tvm.script.ir_module +class DenseAdd_scheduled_gpu: + @T.prim_func + def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + T_matmul_NT_local = T.alloc_buffer([128, 128], dtype="float32", scope="local") + p0_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + p1_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): + for i0_1_i1_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(1, 4, 1, 1): + with T.block("T_matmul_NT_init"): + i = T.axis.spatial(128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + i0_3_init + i0_4_init) + j = T.axis.spatial(128, i1_4_init + i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 + i1_3_init) + T.reads() + T.writes(T_matmul_NT_local[i, j]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":256, "meta_schedule.thread_extent_low_inclusive":16, "meta_schedule.tiling_structure":"SSSRRSRS"}) + T_matmul_NT_local[i, j] = T.float32(0) + for i2_0 in T.serial(32): + for ax0_ax1_fused_0 in T.serial(1): + for ax0_ax1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(2): + with T.block("p0_shared"): + T.where((ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 2 + ax0_ax1_fused_2 < 64) + v0 = T.axis.spatial(128, 
i0_0_i1_0_fused // 4 * 16 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 4) + v1 = T.axis.spatial(128, i2_0 * 4 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 4) + T.reads(p0[v0, v1]) + T.writes(p0_shared[v0, v1]) + p0_shared[v0, v1] = p0[v0, v1] + for ax0_ax1_fused_0 in T.serial(1): + for ax0_ax1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(4): + with T.block("p1_shared"): + T.where((ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2 < 128) + v0 = T.axis.spatial(128, i0_0_i1_0_fused % 4 * 32 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 4) + v1 = T.axis.spatial(128, i2_0 * 4 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 4) + T.reads(p1[v0, v1]) + T.writes(p1_shared[v0, v1]) + p1_shared[v0, v1] = p1[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(1, 1, 4, 4, 1, 1): + with T.block("T_matmul_NT_update"): + i = T.axis.spatial(128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + i0_3 + i0_4) + j = T.axis.spatial(128, i1_4 + i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 + i1_3) + k = T.axis.reduce(128, i2_0 * 4 + i2_1 * 4 + i2_2) + T.reads(T_matmul_NT_local[i, j], p0_shared[i, k], p1_shared[j, k]) + T.writes(T_matmul_NT_local[i, j]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":256, "meta_schedule.thread_extent_low_inclusive":16, "meta_schedule.tiling_structure":"SSSRRSRS"}) + T_matmul_NT_local[i, j] = T_matmul_NT_local[i, j] + p0_shared[i, k] * p1_shared[j, k] + for ax0, ax1 in T.grid(1, 4): + with T.block("T_matmul_NT_local"): + v0 = T.axis.spatial(128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 + ax1) + T.reads(T_matmul_NT_local[v0, v1]) + T.writes(T_add[v0, v1]) + T_add[v0, v1] = T_matmul_NT_local[v0, v1] + T.float32(1) + + +def test_dense_add_gpu(): + def apply_anchor_trace(sch: Schedule) -> None: + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[8, 1, 16, 1, 1]) + l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9], preserve_unit_iters=True) + v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[4, 1, 8, 4, 1]) + l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19], preserve_unit_iters=True) + v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[32, 1, 4]) + l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27], preserve_unit_iters=True) + sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24) + l31 = sch.fuse(l10, l20, preserve_unit_iters=True) + sch.bind(loop=l31, thread_axis="blockIdx.x") + l32 = sch.fuse(l11, l21, preserve_unit_iters=True) + sch.bind(loop=l32, thread_axis="vthread.x") + l33 = sch.fuse(l12, l22, preserve_unit_iters=True) + sch.bind(loop=l33, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=16) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) + b34 = sch.cache_write(block=b0, write_buffer_index=0, 
storage_scope="local") + sch.reverse_compute_at(block=b34, loop=l33, preserve_unit_loops=True, index=-1) + b35 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b35, loop=l28, preserve_unit_loops=True, index=-1) + l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35) + l42 = sch.fuse(l40, l41, preserve_unit_iters=True) + v43 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) + sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v43) + b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) + sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True, index=-1) + l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44) + l51 = sch.fuse(l49, l50, preserve_unit_iters=True) + v52 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) + sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v52) + v53 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53) + sch.enter_postproc() + sch.unannotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch") + l54, l55, l56, l57, l58 = sch.get_loops(block=b35) + l59, l60, l61 = sch.split(loop=l58, factors=[None, 128, 2], preserve_unit_iters=True) + sch.vectorize(loop=l61) + sch.bind(loop=l60, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch") + l62, l63, l64, l65, l66 = sch.get_loops(block=b44) + l67, l68, l69 = sch.split(loop=l66, factors=[None, 128, 4], preserve_unit_iters=True) + sch.vectorize(loop=l69) + sch.bind(loop=l68, thread_axis="threadIdx.x") + b70 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b70, ann_key="meta_schedule.unroll_explicit") + b71, b72, b73, b74 = sch.get_child_blocks(b70) + l75, l76, l77, l78, l79, l80, l81 = sch.get_loops(block=b71) + sch.annotate(block_or_loop=l75, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l75, ann_key="pragma_unroll_explicit", ann_val=1) + l82, l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b72) + sch.annotate(block_or_loop=l82, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l82, ann_key="pragma_unroll_explicit", ann_val=1) + l89, l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b73) + sch.annotate(block_or_loop=l89, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l89, ann_key="pragma_unroll_explicit", ann_val=1) + l99, l100, l101, l102, l103 = sch.get_loops(block=b74) + sch.annotate(block_or_loop=l99, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l99, ann_key="pragma_unroll_explicit", ann_val=1) + b104 = sch.get_block(name="T_matmul_NT", func_name="main") + l105, l106, l107, l108, l109, l110, l111, l112, l113, l114 = sch.get_loops(block=b104) + b115 = sch.decompose_reduction(block=b104, loop=l108) - sch = Schedule(DenseAdd) - target = Target("llvm") + verify(Dense, apply_anchor_trace, DenseAdd, "metal", DenseAdd_scheduled_gpu) - ms.trace_apply.schedule_using_anchor_trace(sch, trace, target) - tvm.ir.assert_structural_equal(DenseAdd_scheduled, sch.mod) +if __name__ == "__main__": + tvm.testing.main() 
From 0b48c14762a33d101026940067beaff9dd9ff0a3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Oct 2022 06:19:53 +0900 Subject: [PATCH 16/28] black --- python/tvm/tir/schedule/analysis.py | 2 +- .../test_meta_schedule_relay_integration.py | 5 +- .../test_meta_schedule_trace_apply.py | 461 +++++++++++++----- 3 files changed, 330 insertions(+), 138 deletions(-) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index b132e4059821..5a4e5840ead0 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -125,4 +125,4 @@ def get_auto_tensorize_mapping_info( def has_block(sch, block_name): - return _ffi_api.HasBlock(sch, block_name); + return _ffi_api.HasBlock(sch, block_name) diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index ddec3dc6f75c..1818177eada9 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -106,8 +106,9 @@ def test_meta_schedule_integration_extract_from_resnet(): for t in extracted_tasks: assert t.task_name in expected_task_names, t.task_name - extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params, - module_equality="anchor-block") + extracted_tasks = ms.relay_integration.extract_tasks( + mod, target="llvm", params=params, module_equality="anchor-block" + ) assert len(extracted_tasks) == 16 diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index 7f2357a77088..6ea4e36f8008 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -40,7 +40,11 @@ def verify(anchor_mod, anchor_trace_fun, target_mod, target, ref): @tvm.script.ir_module class Dense: @T.prim_func - def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_matmul_NT: T.Buffer[(128, 128), "float32"]) -> None: + def main( + p0: T.Buffer[(128, 128), "float32"], + p1: T.Buffer[(128, 128), "float32"], + T_matmul_NT: T.Buffer[(128, 128), "float32"], + ) -> None: # function attr dict T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) # body @@ -50,7 +54,7 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(p0[i, k], p1[j, k]) T.writes(T_matmul_NT[i, j]) - T.block_attr({"layout_free_placeholders":[]}) + T.block_attr({"layout_free_placeholders": []}) with T.init(): T_matmul_NT[i, j] = T.float32(0) T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k] @@ -59,7 +63,11 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" @tvm.script.ir_module class DenseAdd: @T.prim_func - def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: + def main( + p0: T.Buffer[(128, 128), "float32"], + p1: T.Buffer[(128, 128), "float32"], + T_add: T.Buffer[(128, 128), "float32"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body @@ -71,7 +79,7 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(p0[i, k], p1[j, k]) T.writes(T_matmul_NT[i, j]) - T.block_attr({"layout_free_placeholders":[]}) + 
T.block_attr({"layout_free_placeholders": []}) with T.init(): T_matmul_NT[i, j] = T.float32(0) T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k] @@ -91,7 +99,11 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" @tvm.script.ir_module class DenseAdd_scheduled_cpu: @T.prim_func - def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: + def main( + p0: T.Buffer[(128, 128), "float32"], + p1: T.Buffer[(128, 128), "float32"], + T_add: T.Buffer[(128, 128), "float32"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body @@ -103,29 +115,60 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(p1[v0, v1]) T.writes(p1_global[v0 // 64, v1, v0 % 64]) - T.block_attr({"meta_schedule.layout_rewrite_preproc":1}) + T.block_attr({"meta_schedule.layout_rewrite_preproc": 1}) p1_global[v0 // 64, v1, v0 % 64] = p1[v0, v1] for i0_0_i1_0_fused_fused in T.parallel(4): for i0_1, i1_1 in T.grid(8, 1): for i0_2_init, i1_2_init, i0_3_init in T.grid(4, 1, 2): for i1_3_fused_init in T.vectorized(64): with T.block("T_matmul_NT_init"): - i = T.axis.spatial(128, i0_0_i1_0_fused_fused // 2 * 64 + i0_1 * 8 + i0_2_init * 2 + i0_3_init) - j = T.axis.spatial(128, i0_0_i1_0_fused_fused % 2 * 64 + i1_1 * 64 + i1_2_init * 64 + i1_3_fused_init) + i = T.axis.spatial( + 128, + i0_0_i1_0_fused_fused // 2 * 64 + + i0_1 * 8 + + i0_2_init * 2 + + i0_3_init, + ) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused_fused % 2 * 64 + + i1_1 * 64 + + i1_2_init * 64 + + i1_3_fused_init, + ) T.reads() T.writes(T_matmul_NT_global[i, j]) - T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T.block_attr( + { + "layout_free_placeholders": [], + "meta_schedule.tiling_structure": "SSRSRS", + } + ) T_matmul_NT_global[i, j] = T.float32(0) for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(128, 4, 1, 1, 2): for i1_3_fused in T.vectorized(64): with T.block("T_matmul_NT_update"): - i = T.axis.spatial(128, i0_0_i1_0_fused_fused // 2 * 64 + i0_1 * 8 + i0_2 * 2 + i0_3) - j = T.axis.spatial(128, i0_0_i1_0_fused_fused % 2 * 64 + i1_1 * 64 + i1_2 * 64 + i1_3_fused) + i = T.axis.spatial( + 128, i0_0_i1_0_fused_fused // 2 * 64 + i0_1 * 8 + i0_2 * 2 + i0_3 + ) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused_fused % 2 * 64 + i1_1 * 64 + i1_2 * 64 + i1_3_fused, + ) k = T.axis.reduce(128, i2_0 + i2_1) - T.reads(T_matmul_NT_global[i, j], p0[i, k], p1_global[j // 64, k, j % 64]) + T.reads( + T_matmul_NT_global[i, j], p0[i, k], p1_global[j // 64, k, j % 64] + ) T.writes(T_matmul_NT_global[i, j]) - T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) - T_matmul_NT_global[i, j] = T_matmul_NT_global[i, j] + p0[i, k] * p1_global[j // 64, k, j % 64] + T.block_attr( + { + "layout_free_placeholders": [], + "meta_schedule.tiling_structure": "SSRSRS", + } + ) + T_matmul_NT_global[i, j] = ( + T_matmul_NT_global[i, j] + p0[i, k] * p1_global[j // 64, k, j % 64] + ) for ax0 in T.serial(64): for ax1_fused in T.vectorized(64): with T.block("T_matmul_NT_global"): @@ -138,46 +181,69 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" def test_dense_add_cpu(): def apply_anchor_trace(sch: Schedule) -> None: - b0 = sch.get_block(name="T_matmul_NT", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - 
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") - l2, l3, l4 = sch.get_loops(block=b0) - v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[2, 8, 4, 2]) - l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True) - v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[2, 1, 1, 64]) - l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True) - v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[128, 1]) - l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True) - sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) - b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") - sch.reverse_compute_at(block=b25, loop=l17, preserve_unit_loops=True, index=-1) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=160) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64) - v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26) - sch.enter_postproc() - b27 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel") - sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize") - sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit") - b28, b29 = sch.get_child_blocks(b27) - l30, l31, l32, l33, l34, l35, l36, l37, l38, l39 = sch.get_loops(block=b28) - l40 = sch.fuse(l30, l31, preserve_unit_iters=True) - sch.parallel(loop=l40) - l41 = sch.fuse(l39, preserve_unit_iters=True) - sch.vectorize(loop=l41) - l42, l43, l44 = sch.get_loops(block=b29) - l45 = sch.fuse(l42, preserve_unit_iters=True) - sch.parallel(loop=l45) - l46 = sch.fuse(l44, preserve_unit_iters=True) - sch.vectorize(loop=l46) - b47 = sch.get_block(name="T_matmul_NT", func_name="main") - l48, l49, l50, l51, l52, l53, l54, l55, l56 = sch.get_loops(block=b47) - b57 = sch.decompose_reduction(block=b47, loop=l51) - b58 = sch.get_block(name="T_matmul_NT_update", func_name="main") - b59 = sch.cache_read(block=b58, read_buffer_index=2, storage_scope="global") - sch.transform_layout(block=b58, buffer=("read", 2), index_map=tvm.tir.IndexMap.from_func(lambda i0, i1: (floordiv(i0, 64), i1, floormod(i0, 64),), inverse_index_map=lambda i0, i1, i2: (((i0*64) + i2), i1,)), pad_value=None) - sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, n=4, max_innermost_factor=64, decision=[2, 8, 4, 2] + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, n=4, max_innermost_factor=64, decision=[2, 1, 1, 64] + ) + l17, l18, l19, l20 = sch.split( + loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True + ) + v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[128, 1]) + l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True) + sch.reorder(l9, l17, 
l10, l18, l23, l11, l19, l24, l12, l20) + b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + sch.reverse_compute_at(block=b25, loop=l17, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=160) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64) + v26 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0 + ) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26) + sch.enter_postproc() + b27 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit") + b28, b29 = sch.get_child_blocks(b27) + l30, l31, l32, l33, l34, l35, l36, l37, l38, l39 = sch.get_loops(block=b28) + l40 = sch.fuse(l30, l31, preserve_unit_iters=True) + sch.parallel(loop=l40) + l41 = sch.fuse(l39, preserve_unit_iters=True) + sch.vectorize(loop=l41) + l42, l43, l44 = sch.get_loops(block=b29) + l45 = sch.fuse(l42, preserve_unit_iters=True) + sch.parallel(loop=l45) + l46 = sch.fuse(l44, preserve_unit_iters=True) + sch.vectorize(loop=l46) + b47 = sch.get_block(name="T_matmul_NT", func_name="main") + l48, l49, l50, l51, l52, l53, l54, l55, l56 = sch.get_loops(block=b47) + b57 = sch.decompose_reduction(block=b47, loop=l51) + b58 = sch.get_block(name="T_matmul_NT_update", func_name="main") + b59 = sch.cache_read(block=b58, read_buffer_index=2, storage_scope="global") + sch.transform_layout( + block=b58, + buffer=("read", 2), + index_map=tvm.tir.IndexMap.from_func( + lambda i0, i1: ( + floordiv(i0, 64), + i1, + floormod(i0, 64), + ), + inverse_index_map=lambda i0, i1, i2: ( + ((i0 * 64) + i2), + i1, + ), + ), + pad_value=None, + ) + sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) verify(Dense, apply_anchor_trace, DenseAdd, "llvm", DenseAdd_scheduled_cpu) @@ -185,7 +251,11 @@ def apply_anchor_trace(sch: Schedule) -> None: @tvm.script.ir_module class DenseAdd_scheduled_gpu: @T.prim_func - def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: + def main( + p0: T.Buffer[(128, 128), "float32"], + p1: T.Buffer[(128, 128), "float32"], + T_add: T.Buffer[(128, 128), "float32"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body @@ -193,25 +263,70 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" T_matmul_NT_local = T.alloc_buffer([128, 128], dtype="float32", scope="local") p0_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") p1_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") - for i0_0_i1_0_fused in T.thread_binding(32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): + for i0_0_i1_0_fused in T.thread_binding( + 32, + thread="blockIdx.x", + annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}, + ): for i0_1_i1_1_fused in T.thread_binding(1, thread="vthread.x"): for i0_2_i1_2_fused in T.thread_binding(128, thread="threadIdx.x"): for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(1, 4, 1, 1): with T.block("T_matmul_NT_init"): - i = T.axis.spatial(128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + 
i0_3_init + i0_4_init) - j = T.axis.spatial(128, i1_4_init + i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 + i1_3_init) + i = T.axis.spatial( + 128, + i0_0_i1_0_fused // 4 * 16 + + i0_2_i1_2_fused // 8 + + i0_3_init + + i0_4_init, + ) + j = T.axis.spatial( + 128, + i1_4_init + + i0_0_i1_0_fused % 4 * 32 + + i0_2_i1_2_fused % 8 * 4 + + i1_3_init, + ) T.reads() T.writes(T_matmul_NT_local[i, j]) - T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":256, "meta_schedule.thread_extent_low_inclusive":16, "meta_schedule.tiling_structure":"SSSRRSRS"}) + T.block_attr( + { + "layout_free_placeholders": [], + "meta_schedule.thread_extent_high_inclusive": 256, + "meta_schedule.thread_extent_low_inclusive": 16, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) T_matmul_NT_local[i, j] = T.float32(0) for i2_0 in T.serial(32): for ax0_ax1_fused_0 in T.serial(1): for ax0_ax1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(2): with T.block("p0_shared"): - T.where((ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 2 + ax0_ax1_fused_2 < 64) - v0 = T.axis.spatial(128, i0_0_i1_0_fused // 4 * 16 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 4) - v1 = T.axis.spatial(128, i2_0 * 4 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 4) + T.where( + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 2 + + ax0_ax1_fused_2 + < 64 + ) + v0 = T.axis.spatial( + 128, + i0_0_i1_0_fused // 4 * 16 + + ( + ax0_ax1_fused_0 * 256 + + ax0_ax1_fused_1 * 2 + + ax0_ax1_fused_2 + ) + // 4, + ) + v1 = T.axis.spatial( + 128, + i2_0 * 4 + + ( + ax0_ax1_fused_0 * 256 + + ax0_ax1_fused_1 * 2 + + ax0_ax1_fused_2 + ) + % 4, + ) T.reads(p0[v0, v1]) T.writes(p0_shared[v0, v1]) p0_shared[v0, v1] = p0[v0, v1] @@ -219,25 +334,69 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" for ax0_ax1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(4): with T.block("p1_shared"): - T.where((ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2 < 128) - v0 = T.axis.spatial(128, i0_0_i1_0_fused % 4 * 32 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 4) - v1 = T.axis.spatial(128, i2_0 * 4 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 4) + T.where( + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 4 + + ax0_ax1_fused_2 + < 128 + ) + v0 = T.axis.spatial( + 128, + i0_0_i1_0_fused % 4 * 32 + + ( + ax0_ax1_fused_0 * 512 + + ax0_ax1_fused_1 * 4 + + ax0_ax1_fused_2 + ) + // 4, + ) + v1 = T.axis.spatial( + 128, + i2_0 * 4 + + ( + ax0_ax1_fused_0 * 512 + + ax0_ax1_fused_1 * 4 + + ax0_ax1_fused_2 + ) + % 4, + ) T.reads(p1[v0, v1]) T.writes(p1_shared[v0, v1]) p1_shared[v0, v1] = p1[v0, v1] for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(1, 1, 4, 4, 1, 1): with T.block("T_matmul_NT_update"): - i = T.axis.spatial(128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + i0_3 + i0_4) - j = T.axis.spatial(128, i1_4 + i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 + i1_3) + i = T.axis.spatial( + 128, + i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + i0_3 + i0_4, + ) + j = T.axis.spatial( + 128, + i1_4 + + i0_0_i1_0_fused % 4 * 32 + + i0_2_i1_2_fused % 8 * 4 + + i1_3, + ) k = T.axis.reduce(128, i2_0 * 4 + i2_1 * 4 + i2_2) T.reads(T_matmul_NT_local[i, j], p0_shared[i, k], p1_shared[j, k]) T.writes(T_matmul_NT_local[i, j]) - T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":256, 
"meta_schedule.thread_extent_low_inclusive":16, "meta_schedule.tiling_structure":"SSSRRSRS"}) - T_matmul_NT_local[i, j] = T_matmul_NT_local[i, j] + p0_shared[i, k] * p1_shared[j, k] + T.block_attr( + { + "layout_free_placeholders": [], + "meta_schedule.thread_extent_high_inclusive": 256, + "meta_schedule.thread_extent_low_inclusive": 16, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + T_matmul_NT_local[i, j] = ( + T_matmul_NT_local[i, j] + p0_shared[i, k] * p1_shared[j, k] + ) for ax0, ax1 in T.grid(1, 4): with T.block("T_matmul_NT_local"): - v0 = T.axis.spatial(128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + ax0) - v1 = T.axis.spatial(128, i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 + ax1) + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + ax0 + ) + v1 = T.axis.spatial( + 128, i0_0_i1_0_fused % 4 * 32 + i0_2_i1_2_fused % 8 * 4 + ax1 + ) T.reads(T_matmul_NT_local[v0, v1]) T.writes(T_add[v0, v1]) T_add[v0, v1] = T_matmul_NT_local[v0, v1] + T.float32(1) @@ -245,70 +404,102 @@ def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32" def test_dense_add_gpu(): def apply_anchor_trace(sch: Schedule) -> None: - b0 = sch.get_block(name="T_matmul_NT", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - l2, l3, l4 = sch.get_loops(block=b0) - v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[8, 1, 16, 1, 1]) - l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9], preserve_unit_iters=True) - v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[4, 1, 8, 4, 1]) - l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19], preserve_unit_iters=True) - v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[32, 1, 4]) - l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27], preserve_unit_iters=True) - sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24) - l31 = sch.fuse(l10, l20, preserve_unit_iters=True) - sch.bind(loop=l31, thread_axis="blockIdx.x") - l32 = sch.fuse(l11, l21, preserve_unit_iters=True) - sch.bind(loop=l32, thread_axis="vthread.x") - l33 = sch.fuse(l12, l22, preserve_unit_iters=True) - sch.bind(loop=l33, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=16) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b34 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b34, loop=l33, preserve_unit_loops=True, index=-1) - b35 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b35, loop=l28, preserve_unit_loops=True, index=-1) - l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35) - l42 = sch.fuse(l40, l41, preserve_unit_iters=True) - v43 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v43) - b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True, index=-1) - l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44) - l51 = sch.fuse(l49, l50, 
preserve_unit_iters=True) - v52 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) - sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v52) - v53 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53) - sch.enter_postproc() - sch.unannotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch") - l54, l55, l56, l57, l58 = sch.get_loops(block=b35) - l59, l60, l61 = sch.split(loop=l58, factors=[None, 128, 2], preserve_unit_iters=True) - sch.vectorize(loop=l61) - sch.bind(loop=l60, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch") - l62, l63, l64, l65, l66 = sch.get_loops(block=b44) - l67, l68, l69 = sch.split(loop=l66, factors=[None, 128, 4], preserve_unit_iters=True) - sch.vectorize(loop=l69) - sch.bind(loop=l68, thread_axis="threadIdx.x") - b70 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b70, ann_key="meta_schedule.unroll_explicit") - b71, b72, b73, b74 = sch.get_child_blocks(b70) - l75, l76, l77, l78, l79, l80, l81 = sch.get_loops(block=b71) - sch.annotate(block_or_loop=l75, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l75, ann_key="pragma_unroll_explicit", ann_val=1) - l82, l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b72) - sch.annotate(block_or_loop=l82, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l82, ann_key="pragma_unroll_explicit", ann_val=1) - l89, l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b73) - sch.annotate(block_or_loop=l89, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l89, ann_key="pragma_unroll_explicit", ann_val=1) - l99, l100, l101, l102, l103 = sch.get_loops(block=b74) - sch.annotate(block_or_loop=l99, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l99, ann_key="pragma_unroll_explicit", ann_val=1) - b104 = sch.get_block(name="T_matmul_NT", func_name="main") - l105, l106, l107, l108, l109, l110, l111, l112, l113, l114 = sch.get_loops(block=b104) - b115 = sch.decompose_reduction(block=b104, loop=l108) + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8, v9 = sch.sample_perfect_tile( + loop=l2, n=5, max_innermost_factor=64, decision=[8, 1, 16, 1, 1] + ) + l10, l11, l12, l13, l14 = sch.split( + loop=l2, factors=[v5, v6, v7, v8, v9], preserve_unit_iters=True + ) + v15, v16, v17, v18, v19 = sch.sample_perfect_tile( + loop=l3, n=5, max_innermost_factor=64, decision=[4, 1, 8, 4, 1] + ) + l20, l21, l22, l23, l24 = sch.split( + loop=l3, factors=[v15, v16, v17, v18, v19], preserve_unit_iters=True + ) + v25, v26, v27 = sch.sample_perfect_tile( + loop=l4, n=3, max_innermost_factor=64, decision=[32, 1, 4] + ) + l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27], preserve_unit_iters=True) + sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24) + l31 = sch.fuse(l10, l20, preserve_unit_iters=True) + sch.bind(loop=l31, thread_axis="blockIdx.x") + l32 = sch.fuse(l11, l21, preserve_unit_iters=True) + 
sch.bind(loop=l32, thread_axis="vthread.x") + l33 = sch.fuse(l12, l22, preserve_unit_iters=True) + sch.bind(loop=l33, thread_axis="threadIdx.x") + sch.annotate( + block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=16 + ) + sch.annotate( + block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256 + ) + b34 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b34, loop=l33, preserve_unit_loops=True, index=-1) + b35 = sch.cache_read( + block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0] + ) + sch.compute_at(block=b35, loop=l28, preserve_unit_loops=True, index=-1) + l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35) + l42 = sch.fuse(l40, l41, preserve_unit_iters=True) + v43 = sch.sample_categorical( + candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v43) + b44 = sch.cache_read( + block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0] + ) + sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True, index=-1) + l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44) + l51 = sch.fuse(l49, l50, preserve_unit_iters=True) + v52 = sch.sample_categorical( + candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v52) + v53 = sch.sample_categorical( + candidates=[0, 16, 64, 512, 1024], + probs=[ + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + ], + decision=2, + ) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53) + sch.enter_postproc() + sch.unannotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch") + l54, l55, l56, l57, l58 = sch.get_loops(block=b35) + l59, l60, l61 = sch.split(loop=l58, factors=[None, 128, 2], preserve_unit_iters=True) + sch.vectorize(loop=l61) + sch.bind(loop=l60, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch") + l62, l63, l64, l65, l66 = sch.get_loops(block=b44) + l67, l68, l69 = sch.split(loop=l66, factors=[None, 128, 4], preserve_unit_iters=True) + sch.vectorize(loop=l69) + sch.bind(loop=l68, thread_axis="threadIdx.x") + b70 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b70, ann_key="meta_schedule.unroll_explicit") + b71, b72, b73, b74 = sch.get_child_blocks(b70) + l75, l76, l77, l78, l79, l80, l81 = sch.get_loops(block=b71) + sch.annotate(block_or_loop=l75, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l75, ann_key="pragma_unroll_explicit", ann_val=1) + l82, l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b72) + sch.annotate(block_or_loop=l82, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l82, ann_key="pragma_unroll_explicit", ann_val=1) + l89, l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b73) + sch.annotate(block_or_loop=l89, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l89, ann_key="pragma_unroll_explicit", ann_val=1) + l99, l100, l101, l102, l103 = sch.get_loops(block=b74) + sch.annotate(block_or_loop=l99, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l99, ann_key="pragma_unroll_explicit", ann_val=1) + b104 = sch.get_block(name="T_matmul_NT", 
func_name="main") + l105, l106, l107, l108, l109, l110, l111, l112, l113, l114 = sch.get_loops(block=b104) + b115 = sch.decompose_reduction(block=b104, loop=l108) verify(Dense, apply_anchor_trace, DenseAdd, "metal", DenseAdd_scheduled_gpu) From ac24ea30df3cd83cc11cbca228d1c2036912397b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Oct 2022 06:20:34 +0900 Subject: [PATCH 17/28] Revert "Decoupled trace creation and application in Trace::ApplyJSONToschedule" This reverts commit 02df571bff58064927659f6e81e1d35279826825. --- include/tvm/tir/schedule/trace.h | 8 --- python/tvm/tir/schedule/trace.py | 11 ---- src/tir/schedule/trace.cc | 106 +++++++++++-------------------- 3 files changed, 37 insertions(+), 88 deletions(-) diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index 7ed573d12139..b6b3b57226c8 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -148,14 +148,6 @@ class Trace : public runtime::ObjectRef { * \param decisions The decisions made in sampling */ explicit Trace(Array insts, Map decisions); - - /*! - * \brief Restore a trace from its serialized representation - * \param json The JSON-serialized trace - * \return The restored trace - */ - static Trace FromJSON(ObjectRef json); - /*! * \brief Apply a JSON-serialized trace to a TensorIR schedule * \param json The JSON-serialized trace diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index 228db9fea8a0..da599081df3b 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -231,17 +231,6 @@ def with_decision( remove_postproc, ) - @staticmethod - def from_json(json_obj: JSON_TYPE) -> "Trace": - """Apply a JSON-serialized trace to a TensorIR schedule - - Parameters - ---------- - json_obj : JSON_TYPE - The JSON-serialized tracenn - """ - return _ffi_api.TraceFromJSON(json_obj) # type: ignore # pylint: disable=no-member - def simplified(self, remove_postproc: bool) -> "Trace": """Simplify the trace with dead-code elimination diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 02fa513f4180..b90b6b85960f 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -135,35 +135,35 @@ Array TranslateInputRVs(const Array& inputs, Array results; results.reserve(inputs.size()); for (const ObjectRef& input : inputs) { + // Case 3. integer or floating-point number if (input->IsInstance() || input->IsInstance()) { - // Case 3. integer or floating-point number results.push_back(input); - } else if (input->IsInstance()) { - // Case 4. array + continue; + } + // Case 4. array + if (input->IsInstance()) { results.push_back(TranslateInputRVs(Downcast>(input), named_rvs)); - } else if (input->IsInstance()) { - // Case 5. dict - results.push_back(input); - } else if (input->IsInstance()) { - const auto* str = input.as(); - CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names"; - const char* name = str->data; - int64_t size = str->size; - if (auto it = named_rvs.find(name); it != named_rvs.end()) { - // Case 0 & 1. None, BlockRV, LoopRV, VarRV - results.push_back(it->second); - } else { - // Case 2. string - if (size >= 2 && name[0] == '"' && name[size - 1] == '"') { - results.push_back(String(std::string(name + 1, size - 2))); - } else { - // strings without quotation - results.push_back(input); - } - } - } else { + continue; + } + // Case 5. 
dict + if (input->IsInstance()) { results.push_back(input); + continue; } + const auto* str = input.as(); + CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey(); + CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names"; + const char* name = str->data; + int64_t size = str->size; + // Case 2. string + if (size >= 2 && name[0] == '"' && name[size - 1] == '"') { + results.push_back(String(std::string(name + 1, size - 2))); + continue; + } + // Case 0 & 1. None, BlockRV, LoopRV, VarRV + auto it = named_rvs.find(name); + CHECK(it != named_rvs.end()) << "ValueError: The random variable is not defined: " << name; + results.push_back(it->second); } return results; } @@ -255,40 +255,18 @@ void TraceNode::ApplyToSchedule( const Optional& decision)> decision_provider) const { std::unordered_map rv_map; - std::unordered_map named_rvs{{"None", ObjectRef{nullptr}}}; - - auto all_string = [](const Array& objs) { - for (auto obj : objs) { - if (!obj->IsInstance()) { - return false; - } - } - return true; - }; - for (const Instruction& inst : this->insts) { if (remove_postproc && inst->kind->IsPostproc()) { break; } - Array inputs = TranslateInputRVs(TranslateInputRVs(inst->inputs, named_rvs), rv_map); - + Array inputs = TranslateInputRVs(inst->inputs, rv_map); Array attrs = inst->attrs; Optional decision = this->GetDecision(inst); if (decision_provider != nullptr) { decision = decision_provider(inst, inputs, attrs, decision); } Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); - - if (all_string(inst->outputs)) { - Array outputs_str; - for (auto str : inst->outputs) { - outputs_str.push_back(Downcast(str)); - } - TranslateAddOutputRVs(outputs_str, outputs, &named_rvs); - TranslateAddOutputRVs(outputs, outputs, &rv_map); - } else { - TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); - } + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); } } @@ -352,7 +330,7 @@ Array TraceNode::AsPython(bool remove_postproc) const { return py_trace; } -Trace Trace::FromJSON(ObjectRef json) { +void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { Array json_insts{nullptr}; Array json_decisions{nullptr}; // Parse `json` into `json_insts` and `json_decisions` @@ -391,15 +369,13 @@ Trace Trace::FromJSON(ObjectRef json) { decisions[index] = std::move(decision); } // Parse `json_insts` + std::unordered_map named_rvs{{"None", ObjectRef{nullptr}}}; int i = 0; - Array instructions; - Map decisions_map; - for (const ObjectRef& inst_entry : json_insts) { InstructionKind kind{nullptr}; Array inputs{nullptr}; Array attrs{nullptr}; - Array outputs_str{ObjectPtr{nullptr}}; + Array outputs{ObjectPtr{nullptr}}; // Parse the entry try { const auto* arr = inst_entry.as(); @@ -415,33 +391,25 @@ Trace Trace::FromJSON(ObjectRef json) { kind = InstructionKind::Get(arr0->data); inputs = GetRef>(arr1); attrs = GetRef>(arr2); - outputs_str = GetRef>(arr3); + outputs = GetRef>(arr3); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, " "inputs, attrs, outputs], but gets: " << inst_entry; throw; } + // Parse inputs + inputs = TranslateInputRVs(inputs, named_rvs); // Parse attrs if (kind->f_attrs_from_json != nullptr) { attrs = kind->f_attrs_from_json(attrs); } - - Array outputs; - for (auto str : outputs_str) { - outputs.push_back(str); - } - instructions.push_back(Instruction(kind, inputs, attrs, outputs)); - if (decisions[i]) { - decisions_map.Set(instructions.back(), 
decisions[i].value()); - } + // Apply to the schedule + Array new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); + // Parse outputs + TranslateAddOutputRVs(outputs, new_outputs, &named_rvs); ++i; } - return Trace(instructions, decisions_map); -} - -void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { - FromJSON(json)->ApplyToSchedule(sch, false); } /**************** Creation ****************/ @@ -582,6 +550,6 @@ TVM_REGISTER_GLOBAL("tir.schedule.TraceWithDecision") TVM_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") .set_body_typed(Trace::ApplyJSONToSchedule); -TVM_REGISTER_GLOBAL("tir.schedule.TraceFromJSON").set_body_typed(Trace::FromJSON); + } // namespace tir } // namespace tvm From 200a2a5a9375f5ded562a86e4efa4e305e6caf2c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Oct 2022 06:22:22 +0900 Subject: [PATCH 18/28] add tests --- python/tvm/meta_schedule/trace_apply.py | 22 + python/tvm/script/tir/__init__.py | 5 +- python/tvm/script/tir/ty.py | 2 +- .../test_meta_schedule_trace_apply.py | 2399 ++++++++++++++++- 4 files changed, 2344 insertions(+), 84 deletions(-) create mode 100644 python/tvm/meta_schedule/trace_apply.py diff --git a/python/tvm/meta_schedule/trace_apply.py b/python/tvm/meta_schedule/trace_apply.py new file mode 100644 index 000000000000..364a84344684 --- /dev/null +++ b/python/tvm/meta_schedule/trace_apply.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TODO""" +from . import _ffi_api + + +def schedule_using_anchor_trace(sch, anchor_trace, target): + return _ffi_api.ScheduleUsingAnchorTrace(sch, anchor_trace, target) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py index d7db182f9d20..662dd10ec068 100644 --- a/python/tvm/script/tir/__init__.py +++ b/python/tvm/script/tir/__init__.py @@ -25,8 +25,9 @@ # add all floating point and integer datatypes to the module for _dtype in ["float", "uint", "int"]: for _size in ["8", "16", "32", "64"]: - for _lanes in ["", "x4", "x8", "x16", "x32"]: + for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]: from . 
import ty _name = _dtype + _size + _lanes - globals()[_name] = getattr(ty, _name) + if hasattr(ty, _name): + globals()[_name] = getattr(ty, _name) diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index b8323dd4a167..b17b571e88e7 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -202,7 +202,7 @@ def __getitem__(self, args): # add all floating point and integer datatypes to the module for _dtype in ["float", "uint", "int"]: for _size in ["8", "16", "32", "64"]: - for _lanes in ["", "x4", "x8", "x16", "x32"]: + for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]: _name = _dtype + _size + _lanes globals()[_name] = ConcreteType(_name) diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index 6ea4e36f8008..6ff21c72c9ea 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -21,22 +21,12 @@ import tvm.meta_schedule as ms from tvm.script import tir as T from tvm.tir import Schedule, floormod, floordiv +from tvm.tir.tensor_intrin.cuda import * from tvm.target import Target +from tvm.target.codegen import llvm_lookup_intrinsic_id -def verify(anchor_mod, anchor_trace_fun, target_mod, target, ref): - anchor_sch = Schedule(anchor_mod) - anchor_trace_fun(anchor_sch) - trace = anchor_sch.trace - - sch = Schedule(target_mod) - - ms.trace_apply.schedule_using_anchor_trace(sch, trace, Target(target)) - - # print(sch.mod.script()) - tvm.ir.assert_structural_equal(ref, sch.mod) - - +# fmt: off @tvm.script.ir_module class Dense: @T.prim_func @@ -179,73 +169,50 @@ def main( T_add[v0, v1] = T_matmul_NT_global[v0, v1] + T.float32(1) -def test_dense_add_cpu(): - def apply_anchor_trace(sch: Schedule) -> None: - b0 = sch.get_block(name="T_matmul_NT", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") - l2, l3, l4 = sch.get_loops(block=b0) - v5, v6, v7, v8 = sch.sample_perfect_tile( - loop=l2, n=4, max_innermost_factor=64, decision=[2, 8, 4, 2] - ) - l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True) - v13, v14, v15, v16 = sch.sample_perfect_tile( - loop=l3, n=4, max_innermost_factor=64, decision=[2, 1, 1, 64] - ) - l17, l18, l19, l20 = sch.split( - loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True - ) - v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[128, 1]) - l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True) - sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) - b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") - sch.reverse_compute_at(block=b25, loop=l17, preserve_unit_loops=True, index=-1) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=160) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64) - v26 = sch.sample_categorical( - candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0 - ) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26) - sch.enter_postproc() - b27 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel") - sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize") - sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit") - b28, b29 = 
sch.get_child_blocks(b27) - l30, l31, l32, l33, l34, l35, l36, l37, l38, l39 = sch.get_loops(block=b28) - l40 = sch.fuse(l30, l31, preserve_unit_iters=True) - sch.parallel(loop=l40) - l41 = sch.fuse(l39, preserve_unit_iters=True) - sch.vectorize(loop=l41) - l42, l43, l44 = sch.get_loops(block=b29) - l45 = sch.fuse(l42, preserve_unit_iters=True) - sch.parallel(loop=l45) - l46 = sch.fuse(l44, preserve_unit_iters=True) - sch.vectorize(loop=l46) - b47 = sch.get_block(name="T_matmul_NT", func_name="main") - l48, l49, l50, l51, l52, l53, l54, l55, l56 = sch.get_loops(block=b47) - b57 = sch.decompose_reduction(block=b47, loop=l51) - b58 = sch.get_block(name="T_matmul_NT_update", func_name="main") - b59 = sch.cache_read(block=b58, read_buffer_index=2, storage_scope="global") - sch.transform_layout( - block=b58, - buffer=("read", 2), - index_map=tvm.tir.IndexMap.from_func( - lambda i0, i1: ( - floordiv(i0, 64), - i1, - floormod(i0, 64), - ), - inverse_index_map=lambda i0, i1, i2: ( - ((i0 * 64) + i2), - i1, - ), - ), - pad_value=None, - ) - sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) - - verify(Dense, apply_anchor_trace, DenseAdd, "llvm", DenseAdd_scheduled_cpu) +@tvm.script.ir_module +class DenseAdd_cpu_no_write_cache: + @T.prim_func + def main(p0: T.Buffer[(128, 128), "float32"], p1: T.Buffer[(128, 128), "float32"], T_add: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + T_matmul_NT = T.alloc_buffer([128, 128], dtype="float32") + p1_global = T.alloc_buffer([8, 4, 16, 32], dtype="float32") + for ax0, ax1 in T.grid(128, 128): + with T.block("p1_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(p1[v0, v1]) + T.writes(p1_global[v1 // 16, v0 // 32, v1 % 16, v0 % 32]) + T.block_attr({"meta_schedule.layout_rewrite_preproc":1}) + p1_global[v1 // 16, v0 // 32, v1 % 16, v0 % 32] = p1[v0, v1] + for i0_0_i1_0_i0_1_i1_1_fused in T.parallel(16, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): + for i0_2_init, i1_2_init, i0_3_init in T.grid(4, 4, 2): + for i1_3_fused_init in T.vectorized(32): + with T.block("T_matmul_NT_init"): + i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused // 4 * 32 + i0_0_i1_0_i0_1_i1_1_fused % 4 * 8 + i0_2_init * 2 + i0_3_init) + j = T.axis.spatial(128, i1_2_init * 32 + i1_3_fused_init) + T.reads() + T.writes(T_matmul_NT[i, j]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T_matmul_NT[i, j] = T.float32(0) + for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(8, 4, 4, 16, 2): + for i1_3_fused in T.vectorized(32): + with T.block("T_matmul_NT_update"): + i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused // 4 * 32 + i0_0_i1_0_i0_1_i1_1_fused % 4 * 8 + i0_2 * 2 + i0_3) + j = T.axis.spatial(128, i1_2 * 32 + i1_3_fused) + k = T.axis.reduce(128, i2_0 * 16 + i2_1) + T.reads(T_matmul_NT[i, j], p0[i, k], p1_global[k // 16, j // 32, k % 16, j % 32]) + T.writes(T_matmul_NT[i, j]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1_global[k // 16, j // 32, k % 16, j % 32] + for i0_i1_fused in T.parallel(16384): + with T.block("T_add"): + ax0 = T.axis.spatial(128, i0_i1_fused // 128) + ax1 = T.axis.spatial(128, i0_i1_fused % 128) + T.reads(T_matmul_NT[ax0, ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = T_matmul_NT[ax0, ax1] + T.float32(1) 
@tvm.script.ir_module @@ -402,6 +369,1312 @@ def main( T_add[v0, v1] = T_matmul_NT_local[v0, v1] + T.float32(1) +@tvm.script.ir_module +class Conv2dInt8: + @T.prim_func + def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[(1, 1, 1, 256), "int64"], p5: T.Buffer[(1, 1, 1, 256), "int64"], p6: T.Buffer[(1, 1, 1, 256), "int64"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], compute: T.Buffer[(16, 56, 56, 256), "int32"]) -> None: + # function attr dict + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + pad_temp = T.alloc_buffer([16, 56, 56, 64], dtype="int8") + conv2d_nhwc = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_subtract = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_multiply = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_add_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_right_shift = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_cast_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add_2 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + compute_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast_2 = T.alloc_buffer([16, 56, 56, 256], dtype="uint8") + T_cast_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_subtract_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 64): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 56, 56, 256, 1, 1, 64): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = 0 + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_cast"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3]) + T.writes(T_cast[ax0, ax1, ax2, ax3]) + T_cast[ax0, ax1, ax2, ax3] = T.cast(T_add[ax0, ax1, ax2, ax3], "int64") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_cast[ax0, ax1, ax2, ax3], p4[0, 0, 0, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = 
T_cast[ax0, ax1, ax2, ax3] * p4[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply[ax0, ax1, ax2, ax3], p5[0, 0, 0, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_multiply[ax0, ax1, ax2, ax3] + p5[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_right_shift"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3]) + T.writes(T_right_shift[ax0, ax1, ax2, ax3]) + T_right_shift[ax0, ax1, ax2, ax3] = T.shift_right(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3], dtype="int64") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_cast_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_right_shift[ax0, ax1, ax2, ax3]) + T.writes(T_cast_1[ax0, ax1, ax2, ax3]) + T_cast_1[ax0, ax1, ax2, ax3] = T.cast(T_right_shift[ax0, ax1, ax2, ax3], "int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p7[()], T_cast_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_2[ax0, ax1, ax2, ax3]) + T_add_2[ax0, ax1, ax2, ax3] = p7[()] + T_cast_1[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("compute"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_2[i0_2, i1_2, i2_2, i3_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) + compute_1[i0_2, i1_2, i2_2, i3_2] = T.max(T.min(T_add_2[i0_2, i1_2, i2_2, i3_2], 255), 0) + for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 56, 56, 256): + with T.block("T_cast_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) + T.reads(compute_1[ax0, ax1, ax2, ax3]) + T.writes(T_cast_2[ax0, ax1, ax2, ax3]) + T_cast_2[ax0, ax1, ax2, ax3] = T.cast(compute_1[ax0, ax1, ax2, ax3], "uint8") + for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 56, 56, 256): + with T.block("T_cast_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) + T.reads(T_cast_2[ax0, ax1, ax2, ax3]) + T.writes(T_cast_3[ax0, ax1, ax2, ax3]) + T_cast_3[ax0, ax1, ax2, ax3] = T.cast(T_cast_2[ax0, ax1, ax2, ax3], "int32") + for i0_5, i1_5, i2_5, i3_5 in T.grid(16, 56, 56, 256): + with T.block("T_subtract_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) + T.reads(T_cast_3[ax0, ax1, ax2, ax3], p8[0]) + T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) + T_subtract_1[ax0, ax1, ax2, ax3] = T_cast_3[ax0, ax1, ax2, ax3] - p8[0] + for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 56, 56, 256): + with T.block("compute_1"): + i0_7, i1_7, i2_7, i3_7 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) + T.reads(T_subtract_1[i0_7, i1_7, i2_7, i3_7]) + T.writes(compute[i0_7, i1_7, i2_7, i3_7]) + compute[i0_7, i1_7, i2_7, i3_7] = T.q_multiply_shift(T_subtract_1[i0_7, i1_7, i2_7, i3_7], 1963325822, 31, 1, dtype="int32") + + +@tvm.script.ir_module +class Conv2dInt8_target: + @T.prim_func + def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[(1, 1, 1, 256), "int64"], p5: T.Buffer[(1, 1, 1, 256), "int64"], p6: T.Buffer[(1, 1, 1, 256), "int64"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "uint8"]) -> None: + # function attr dict + T.func_attr({"global_symbol": 
"main", "tir.noalias": True}) + # body + # with T.block("root") + pad_temp = T.alloc_buffer([16, 56, 56, 64], dtype="int8") + conv2d_nhwc = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_subtract = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_multiply = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_add_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_right_shift = T.alloc_buffer([16, 56, 56, 256], dtype="int64") + T_cast_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add_2 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + compute_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast_2 = T.alloc_buffer([16, 56, 56, 256], dtype="uint8") + T_cast_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_subtract_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + compute_2 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_add_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + compute_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + T_cast_4 = T.alloc_buffer([16, 56, 56, 256], dtype="uint8") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 64): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 56, 56, 256, 1, 1, 64): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = 0 + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_cast"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3]) + T.writes(T_cast[ax0, ax1, ax2, ax3]) + T_cast[ax0, ax1, ax2, ax3] = T.cast(T_add[ax0, ax1, ax2, ax3], "int64") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_cast[ax0, ax1, ax2, ax3], p4[0, 0, 0, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_cast[ax0, ax1, ax2, ax3] * p4[0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply[ax0, ax1, ax2, ax3], p5[0, 0, 0, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_multiply[ax0, ax1, ax2, ax3] + p5[0, 0, 0, ax3] + for 
i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_right_shift"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3]) + T.writes(T_right_shift[ax0, ax1, ax2, ax3]) + T_right_shift[ax0, ax1, ax2, ax3] = T.shift_right(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3], dtype="int64") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_cast_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_right_shift[ax0, ax1, ax2, ax3]) + T.writes(T_cast_1[ax0, ax1, ax2, ax3]) + T_cast_1[ax0, ax1, ax2, ax3] = T.cast(T_right_shift[ax0, ax1, ax2, ax3], "int32") + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p7[()], T_cast_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_2[ax0, ax1, ax2, ax3]) + T_add_2[ax0, ax1, ax2, ax3] = p7[()] + T_cast_1[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): + with T.block("compute"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_2[i0_2, i1_2, i2_2, i3_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) + compute_1[i0_2, i1_2, i2_2, i3_2] = T.max(T.min(T_add_2[i0_2, i1_2, i2_2, i3_2], 255), 0) + for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 56, 56, 256): + with T.block("T_cast_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) + T.reads(compute_1[ax0, ax1, ax2, ax3]) + T.writes(T_cast_2[ax0, ax1, ax2, ax3]) + T_cast_2[ax0, ax1, ax2, ax3] = T.cast(compute_1[ax0, ax1, ax2, ax3], "uint8") + for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 56, 56, 256): + with T.block("T_cast_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) + T.reads(T_cast_2[ax0, ax1, ax2, ax3]) + T.writes(T_cast_3[ax0, ax1, ax2, ax3]) + T_cast_3[ax0, ax1, ax2, ax3] = T.cast(T_cast_2[ax0, ax1, ax2, ax3], "int32") + for i0_5, i1_5, i2_5, i3_5 in T.grid(16, 56, 56, 256): + with T.block("T_subtract_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) + T.reads(T_cast_3[ax0, ax1, ax2, ax3], p8[0]) + T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) + T_subtract_1[ax0, ax1, ax2, ax3] = T_cast_3[ax0, ax1, ax2, ax3] - p8[0] + for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 56, 56, 256): + with T.block("compute_1"): + i0_7, i1_7, i2_7, i3_7 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) + T.reads(T_subtract_1[i0_7, i1_7, i2_7, i3_7]) + T.writes(compute_2[i0_7, i1_7, i2_7, i3_7]) + compute_2[i0_7, i1_7, i2_7, i3_7] = T.q_multiply_shift(T_subtract_1[i0_7, i1_7, i2_7, i3_7], 1098990753, 31, 1, dtype="int32") + for i0_8, i1_8, i2_8, i3_8 in T.grid(16, 56, 56, 256): + with T.block("T_add_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_8, i1_8, i2_8, i3_8]) + T.reads(compute_2[ax0, ax1, ax2, ax3], p9[ax0, ax1, ax2, ax3]) + T.writes(T_add_3[ax0, ax1, ax2, ax3]) + T_add_3[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] + p9[ax0, ax1, ax2, ax3] + for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 56, 56, 256): + with T.block("compute_2"): + i0_10, i1_10, i2_10, i3_10 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) + T.reads(T_add_3[i0_10, i1_10, i2_10, i3_10]) + T.writes(compute_3[i0_10, i1_10, i2_10, i3_10]) + compute_3[i0_10, i1_10, i2_10, i3_10] = T.max(T.min(T_add_3[i0_10, i1_10, i2_10, i3_10], 255), 0) + for i0_11, i1_11, i2_11, i3_11 in T.grid(16, 56, 56, 256): + with T.block("T_cast_4"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_11, i1_11, i2_11, i3_11]) + T.reads(compute_3[ax0, ax1, ax2, ax3]) + 
T.writes(T_cast_4[ax0, ax1, ax2, ax3]) + T_cast_4[ax0, ax1, ax2, ax3] = T.cast(compute_3[ax0, ax1, ax2, ax3], "uint8") + for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256): + with T.block("compute_3"): + i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12]) + T.reads(T_cast_4[i0_13, i1_13, i2_13, i3_13]) + T.writes(compute[i0_13, i1_13, i2_13, i3_13]) + compute[i0_13, i1_13, i2_13, i3_13] = T.max(T.min(T_cast_4[i0_13, i1_13, i2_13, i3_13], T.uint8(255)), T.uint8(0)) + + +@tvm.script.ir_module +class Conv2dInt8_tensorcore_scheduled: + @T.prim_func + def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[(1, 1, 1, 256), "int64"], p5: T.Buffer[(1, 1, 1, 256), "int64"], p6: T.Buffer[(1, 1, 1, 256), "int64"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "uint8"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + a0 = T.var("int32") + a1 = T.var("int32") + b0 = T.var("int32") + b1 = T.var("int32") + c0 = T.var("int32") + c1 = T.var("int32") + d0 = T.var("int32") + d0_1 = T.var("int32") + d0_2 = T.var("int32") + d0_3 = T.var("int32") + d1 = T.var("int32") + d1_1 = T.var("int32") + d1_2 = T.var("int32") + d1_3 = T.var("int32") + s0 = T.var("int32") + s0_1 = T.var("int32") + s0_2 = T.var("int32") + s1 = T.var("int32") + s1_1 = T.var("int32") + s1_2 = T.var("int32") + # body + # with T.block("root") + conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator") + pad_temp_reindex_shared = T.alloc_buffer([50176, 64], dtype="int8", scope="shared") + p1_reindex_shared = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="shared") + pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer([50176, 64], dtype="int8", scope="wmma.matrix_a") + p1_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="wmma.matrix_b") + for ax2_0_0_ax3_0_0_fused in T.thread_binding(3136, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":512, "pragma_unroll_explicit":1}): + for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="vthread.x"): + for ax2_0_2_ax3_0_2_fused in T.thread_binding(16, thread="threadIdx.x"): + for ax0_0, ax1_0 in T.grid(1, 1): + for ax2_0_3_init, ax3_0_3_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1): + with T.block("conv2d_nhwc_o_init"): + v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init) + v3_o = T.axis.spatial(16, ax3_0_4_init + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[d1, d0], scope="wmma.accumulator", offset_factor=16) + T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // d1 // 16 * (d1 // 16) + C.elem_offset % d1 // 16, T.float32(0), dtype="handle")) + for ax4_0_0 in T.serial(2): + for 
ax0_ax1_fused_0 in T.serial(16): + for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(16): + with T.block("pad_temp_reindex_shared"): + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 32) + T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + T.writes(pad_temp_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]]}) + pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(8): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(8): + with T.block("p1_reindex_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) // 32) + v3 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) % 32) + T.reads(p1[v2, v0, v1, v3]) + T.writes(p1_reindex_shared[v0, v1, v2, v3]) + T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]]}) + p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] + for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): + for ax0_0_1, ax1_0_1 in T.grid(1, 2): + with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2) + v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0_1) + T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + A = T.match_buffer(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[s1, s0], scope="shared", offset_factor=16) + C_1 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[d1_1, d0_1], scope="wmma.matrix_a", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // d1_1 // 16 * (d1_1 // 16) + C_1.elem_offset % d1_1 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A.data, A.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle")) + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 1, 2): + with T.block("p1_reindex_shared_wmma.matrix_b_o"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2) + v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax3_0) + T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + A_1 = T.match_buffer(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[s1_1, s0_1], scope="shared", offset_factor=16) + C_2 = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[d1_2, d0_2], scope="wmma.matrix_b", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // d1_2 // 16 * (d1_2 // 16) + 
C_2.elem_offset % d1_2 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A_1.data, A_1.elem_offset, s1_1 * 16, 1, dtype="handle"), s1_1, "col_major", dtype="handle")) + for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1): + with T.block("conv2d_nhwc_o_update"): + v0 = T.axis.reduce(1, 0) + v1 = T.axis.reduce(1, 0) + v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) + v3_o = T.axis.spatial(16, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3) + v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + A_2 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[a1, a0], scope="wmma.matrix_a", offset_factor=16) + B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[b1, b0], scope="wmma.matrix_b", offset_factor=16) + C_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[c1, c0], scope="wmma.accumulator", offset_factor=16) + T.evaluate(T.tvm_mma_sync(C_3.data, C_3.elem_offset // c1 // 16 * (c1 // 16) + C_3.elem_offset % c1 // 16, A_2.data, A_2.elem_offset // a1 // 16 * (a1 // 16) + A_2.elem_offset % a1 // 16, B.data, B.elem_offset // b1 // 16 * (b1 // 16) + B.elem_offset % b1 // 16, C_3.data, C_3.elem_offset // c1 // 16 * (c1 // 16) + C_3.elem_offset % c1 // 16, dtype="handle")) + for ax0_0, ax1_0 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2) + v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + A_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[d1_3, d0_3], scope="wmma.accumulator", offset_factor=16) + C_4 = T.match_buffer(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[s1_2, s0_2], scope="shared", offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(A_3.data, 16, 16, 16, A_3.elem_offset // d1_3 // 16 * (d1_3 // 16) + A_3.elem_offset % d1_3 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), C_4.data, C_4.elem_offset, s1_2 * 16, 2, dtype="handle"), s1_2, "row_major", dtype="handle")) + for ax0, ax1_0 in T.grid(128, 2): + for ax1_1 in T.thread_binding(16, thread="threadIdx.x"): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + ax0) + 
v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + ax1_0 * 16 + ax1_1) + T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[0, 0, 0, v1], p5[0, 0, 0, v1], p6[0, 0, 0, v1], p8[0], p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + T.writes(compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.max(T.min(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(p7[()] + T.cast(T.shift_right(T.cast(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], "int64") * p4[0, 0, 0, v1] + p5[0, 0, 0, v1], p6[0, 0, 0, v1], dtype="int64"), "int32"), 255), 0), "uint8"), "int32") - p8[0], 1098990753, 31, 1, dtype="int32") + p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1], 255), 0), "uint8"), T.uint8(255)), T.uint8(0)) + + +@tvm.script.ir_module +class Conv2dInt8_NCHWc: + @T.prim_func + def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1, 4, 16, 4), "int8"], p2: T.Buffer[(1, 128, 1, 1, 16), "int32"], p3: T.Buffer[(1, 128, 1, 1, 16), "float32"], p4: T.Buffer[1, "float32"], p5: T.Buffer[(1, 128, 7, 7, 16), "int32"], compute: T.Buffer[(1, 128, 7, 7, 16), "uint8"]) -> None: + # function attr dict + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + compile_engine_const = T.alloc_buffer([], dtype="float32") + conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_add = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_1 = T.alloc_buffer([], dtype="float32") + T_add_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compute_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + T_cast_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_subtract = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_2 = T.alloc_buffer([], dtype="float32") + T_add_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_4 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_add_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compute_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_5 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = T.float32(0.94537687301635742) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 128, 7, 7, 16, 1, 1, 32, 4, 4): + with T.block("conv2d_NCHWc_int8"): + n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + T.block_attr({"schedule_rule":"meta_schedule.conv2d_NCHWc_int8", "workload":["conv2d_NCHWc_int8.x86", ["TENSOR", [1, 32, 7, 7, 16], "uint8"], ["TENSOR", [128, 32, 1, 1, 4, 16, 4], "int8"], [1, 1], [0, 0, 0, 0], [1, 1], 
"NCHW16c", "NCHW16c", "int32"]}) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.cast(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32") * T.cast(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4], p2[ax0, ax1, 0, 0, ax4]) + T.writes(T_add[ax0, ax1, ax2, ax3, ax4]) + T_add[ax0, ax1, ax2, ax3, ax4] = conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4] + p2[ax0, ax1, 0, 0, ax4] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast[ax0, ax1, ax2, ax3, ax4]) + T_cast[ax0, ax1, ax2, ax3, ax4] = T.cast(T_add[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast[ax0, ax1, ax2, ax3, ax4], p3[ax0, ax1, 0, 0, ax4]) + T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4]) + T_multiply[ax0, ax1, ax2, ax3, ax4] = T_cast[ax0, ax1, ax2, ax3, ax4] * p3[ax0, ax1, 0, 0, ax4] + with T.block("compile_engine_const_1"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_1[()]) + compile_engine_const_1[()] = T.float32(54.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply[ax0, ax1, ax2, ax3, ax4], compile_engine_const_1[()]) + T.writes(T_add_1[ax0, ax1, ax2, ax3, ax4]) + T_add_1[ax0, ax1, ax2, ax3, ax4] = T_multiply[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_1[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor[ax0, ax1, ax2, ax3, ax4]) + T_floor[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_1[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_1[ax0, ax1, ax2, ax3, ax4]) + T_cast_1[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor[ax0, ax1, ax2, ax3, ax4], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_1[i0_1, i1_1, i2_1, i3_1, i4_1]) + T.writes(compute_1[i0_1, i1_1, i2_1, i3_1, i4_1]) + compute_1[i0_1, i1_1, i2_1, i3_1, i4_1] = T.max(T.min(T_cast_1[i0_1, i1_1, i2_1, i3_1, i4_1], 255), 0) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compute_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_2[ax0, ax1, ax2, ax3, ax4]) + T_cast_2[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_1[ax0, ax1, ax2, ax3, ax4], "uint8") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_2[ax0, ax1, ax2, ax3, 
ax4]) + T.writes(T_cast_3[ax0, ax1, ax2, ax3, ax4]) + T_cast_3[ax0, ax1, ax2, ax3, ax4] = T.cast(T_cast_2[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_3[ax0, ax1, ax2, ax3, ax4], p4[0]) + T.writes(T_subtract[ax0, ax1, ax2, ax3, ax4]) + T_subtract[ax0, ax1, ax2, ax3, ax4] = T_cast_3[ax0, ax1, ax2, ax3, ax4] - p4[0] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compile_engine_const[()], T_subtract[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4]) + T_multiply_1[ax0, ax1, ax2, ax3, ax4] = compile_engine_const[()] * T_subtract[ax0, ax1, ax2, ax3, ax4] + with T.block("compile_engine_const_2"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_2[()]) + compile_engine_const_2[()] = T.float32(0.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply_1[ax0, ax1, ax2, ax3, ax4], compile_engine_const_2[()]) + T.writes(T_add_2[ax0, ax1, ax2, ax3, ax4]) + T_add_2[ax0, ax1, ax2, ax3, ax4] = T_multiply_1[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_2[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor_1[ax0, ax1, ax2, ax3, ax4]) + T_floor_1[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_2[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_4"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_4[ax0, ax1, ax2, ax3, ax4]) + T_cast_4[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_1[ax0, ax1, ax2, ax3, ax4], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_4[ax0, ax1, ax2, ax3, ax4], p5[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_add_3[ax0, ax1, ax2, ax3, ax4]) + T_add_3[ax0, ax1, ax2, ax3, ax4] = T_cast_4[ax0, ax1, ax2, ax3, ax4] + p5[ax0, ax1, ax2, ax3, ax4] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2, i4_2 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_3[i0_2, i1_2, i2_2, i3_2, i4_2]) + T.writes(compute_2[i0_2, i1_2, i2_2, i3_2, i4_2]) + compute_2[i0_2, i1_2, i2_2, i3_2, i4_2] = T.max(T.min(T_add_3[i0_2, i1_2, i2_2, i3_2, i4_2], 255), 0) + for i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_5"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_3, i1_3, i2_3, i3_3, i4_3]) + T.reads(compute_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_5[ax0, ax1, ax2, ax3, ax4]) + T_cast_5[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_2[ax0, ax1, ax2, ax3, ax4], "uint8") + for i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute_2"): + i0_5, i1_5, i2_5, i3_5, i4_5 = T.axis.remap("SSSSS", [i0_4, i1_4, i2_4, i3_4, i4_4]) + T.reads(T_cast_5[i0_5, i1_5, i2_5, i3_5, i4_5]) + T.writes(compute[i0_5, i1_5, i2_5, i3_5, i4_5]) + compute[i0_5, i1_5, i2_5, i3_5, i4_5] = T.max(T.min(T_cast_5[i0_5, i1_5, i2_5, i3_5, 
i4_5], T.uint8(255)), T.uint8(0)) + + +@tvm.script.ir_module +class Conv2dInt8_NCHWc_target: + @T.prim_func + def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1, 4, 16, 4), "int8"], p2: T.Buffer[(1, 128, 1, 1, 16), "int32"], p3: T.Buffer[(1, 128, 1, 1, 16), "float32"], p4: T.Buffer[1, "float32"], p5: T.Buffer[(1, 128, 7, 7, 16), "uint8"], T_cast: T.Buffer[(1, 128, 7, 7, 16), "int32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + compile_engine_const = T.alloc_buffer([], dtype="float32") + conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_add = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_1 = T.alloc_buffer([], dtype="float32") + T_add_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compute = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + T_cast_4 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_subtract = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_2 = T.alloc_buffer([], dtype="float32") + T_add_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_5 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compile_engine_const_3 = T.alloc_buffer([], dtype="float32") + T_cast_6 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_multiply_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + compile_engine_const_4 = T.alloc_buffer([], dtype="float32") + T_add_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_floor_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="float32") + T_cast_7 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_add_4 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + compute_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + T_cast_8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + compute_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") + with T.block("compile_engine_const"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const[()]) + compile_engine_const[()] = T.float32(0.95489668846130371) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 128, 7, 7, 16, 1, 1, 32, 4, 4): + with T.block("conv2d_NCHWc_int8"): + n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + T.block_attr({"schedule_rule":"meta_schedule.conv2d_NCHWc_int8", "workload":["conv2d_NCHWc_int8.x86", ["TENSOR", [1, 32, 7, 7, 16], "uint8"], ["TENSOR", [128, 32, 1, 1, 4, 16, 4], "int8"], [1, 1], [0, 0, 0, 0], [1, 1], "NCHW16c", "NCHW16c", "int32"]}) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.cast(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + 
ic_s_inner], "int32") * T.cast(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4], p2[ax0, ax1, 0, 0, ax4]) + T.writes(T_add[ax0, ax1, ax2, ax3, ax4]) + T_add[ax0, ax1, ax2, ax3, ax4] = conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4] + p2[ax0, ax1, 0, 0, ax4] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_1[ax0, ax1, ax2, ax3, ax4]) + T_cast_1[ax0, ax1, ax2, ax3, ax4] = T.cast(T_add[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_1[ax0, ax1, ax2, ax3, ax4], p3[ax0, ax1, 0, 0, ax4]) + T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4]) + T_multiply[ax0, ax1, ax2, ax3, ax4] = T_cast_1[ax0, ax1, ax2, ax3, ax4] * p3[ax0, ax1, 0, 0, ax4] + with T.block("compile_engine_const_1"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_1[()]) + compile_engine_const_1[()] = T.float32(65.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply[ax0, ax1, ax2, ax3, ax4], compile_engine_const_1[()]) + T.writes(T_add_1[ax0, ax1, ax2, ax3, ax4]) + T_add_1[ax0, ax1, ax2, ax3, ax4] = T_multiply[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_1[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor[ax0, ax1, ax2, ax3, ax4]) + T_floor[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_1[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_2[ax0, ax1, ax2, ax3, ax4]) + T_cast_2[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor[ax0, ax1, ax2, ax3, ax4], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_2[i0_1, i1_1, i2_1, i3_1, i4_1]) + T.writes(compute[i0_1, i1_1, i2_1, i3_1, i4_1]) + compute[i0_1, i1_1, i2_1, i3_1, i4_1] = T.max(T.min(T_cast_2[i0_1, i1_1, i2_1, i3_1, i4_1], 255), 0) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compute[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_3[ax0, ax1, ax2, ax3, ax4]) + T_cast_3[ax0, ax1, ax2, ax3, ax4] = T.cast(compute[ax0, ax1, ax2, ax3, ax4], "uint8") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_3[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_4[ax0, ax1, ax2, ax3, ax4]) + T_cast_4[ax0, ax1, ax2, ax3, ax4] = T.cast(T_cast_3[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3, ax4 = 
T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_4[ax0, ax1, ax2, ax3, ax4], p4[0]) + T.writes(T_subtract[ax0, ax1, ax2, ax3, ax4]) + T_subtract[ax0, ax1, ax2, ax3, ax4] = T_cast_4[ax0, ax1, ax2, ax3, ax4] - p4[0] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compile_engine_const[()], T_subtract[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4]) + T_multiply_1[ax0, ax1, ax2, ax3, ax4] = compile_engine_const[()] * T_subtract[ax0, ax1, ax2, ax3, ax4] + with T.block("compile_engine_const_2"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_2[()]) + compile_engine_const_2[()] = T.float32(0.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply_1[ax0, ax1, ax2, ax3, ax4], compile_engine_const_2[()]) + T.writes(T_add_2[ax0, ax1, ax2, ax3, ax4]) + T_add_2[ax0, ax1, ax2, ax3, ax4] = T_multiply_1[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_2[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor_1[ax0, ax1, ax2, ax3, ax4]) + T_floor_1[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_2[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_4"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_5[ax0, ax1, ax2, ax3, ax4]) + T_cast_5[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_1[ax0, ax1, ax2, ax3, ax4], "int32") + with T.block("compile_engine_const_3"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_3[()]) + compile_engine_const_3[()] = T.float32(0.71245479583740234) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_5"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(p5[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_6[ax0, ax1, ax2, ax3, ax4]) + T_cast_6[ax0, ax1, ax2, ax3, ax4] = T.cast(p5[ax0, ax1, ax2, ax3, ax4], "float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_multiply_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(compile_engine_const_3[()], T_cast_6[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4]) + T_multiply_2[ax0, ax1, ax2, ax3, ax4] = compile_engine_const_3[()] * T_cast_6[ax0, ax1, ax2, ax3, ax4] + with T.block("compile_engine_const_4"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(compile_engine_const_4[()]) + compile_engine_const_4[()] = T.float32(0.5) + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_multiply_2[ax0, ax1, ax2, ax3, ax4], compile_engine_const_4[()]) + T.writes(T_add_3[ax0, ax1, ax2, ax3, ax4]) + T_add_3[ax0, ax1, ax2, ax3, ax4] = T_multiply_2[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_4[()] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_floor_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_3[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_floor_2[ax0, ax1, ax2, ax3, ax4]) + T_floor_2[ax0, ax1, ax2, ax3, 
ax4] = T.floor(T_add_3[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_6"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_floor_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_7[ax0, ax1, ax2, ax3, ax4]) + T_cast_7[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_2[ax0, ax1, ax2, ax3, ax4], "int32") + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("T_add_4"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_cast_5[ax0, ax1, ax2, ax3, ax4], T_cast_7[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_add_4[ax0, ax1, ax2, ax3, ax4]) + T_add_4[ax0, ax1, ax2, ax3, ax4] = T_cast_5[ax0, ax1, ax2, ax3, ax4] + T_cast_7[ax0, ax1, ax2, ax3, ax4] + for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2, i4_2 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_add_4[i0_2, i1_2, i2_2, i3_2, i4_2]) + T.writes(compute_1[i0_2, i1_2, i2_2, i3_2, i4_2]) + compute_1[i0_2, i1_2, i2_2, i3_2, i4_2] = T.max(T.min(T_add_4[i0_2, i1_2, i2_2, i3_2, i4_2], 255), 0) + for i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_7"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_3, i1_3, i2_3, i3_3, i4_3]) + T.reads(compute_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast_8[ax0, ax1, ax2, ax3, ax4]) + T_cast_8[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_1[ax0, ax1, ax2, ax3, ax4], "uint8") + for i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(1, 128, 7, 7, 16): + with T.block("compute_2"): + i0_5, i1_5, i2_5, i3_5, i4_5 = T.axis.remap("SSSSS", [i0_4, i1_4, i2_4, i3_4, i4_4]) + T.reads(T_cast_8[i0_5, i1_5, i2_5, i3_5, i4_5]) + T.writes(compute_2[i0_5, i1_5, i2_5, i3_5, i4_5]) + compute_2[i0_5, i1_5, i2_5, i3_5, i4_5] = T.max(T.min(T_cast_8[i0_5, i1_5, i2_5, i3_5, i4_5], T.uint8(255)), T.uint8(0)) + for i0_6, i1_6, i2_6, i3_6, i4_6 in T.grid(1, 128, 7, 7, 16): + with T.block("T_cast_8"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_6, i1_6, i2_6, i3_6, i4_6]) + T.reads(compute_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_cast[ax0, ax1, ax2, ax3, ax4]) + T_cast[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_2[ax0, ax1, ax2, ax3, ax4], "int32") + + +def get_conv2d_vnni_mod(intrin_id): + @tvm.script.ir_module + class Conv2dInt8_NCHWc_scheduled: + @T.prim_func + def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1, 4, 16, 4), "int8"], p2: T.Buffer[(1, 128, 1, 1, 16), "int32"], p3: T.Buffer[(1, 128, 1, 1, 16), "float32"], p4: T.Buffer[1, "float32"], p5: T.Buffer[(1, 128, 7, 7, 16), "uint8"], T_cast: T.Buffer[(1, 128, 7, 7, 16), "int32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") + for i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused in T.parallel(128, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): + for i2_1, i3_1, i4_0_1 in T.grid(7, 1, 1): + for i5_0, i6_0 in T.grid(1, 1): + for i1_2_init, i2_2_init, i3_2_init, i1_3_init, i2_3_init, i3_3_init in T.grid(1, 1, 1, 1, 1, 7): + with T.block("conv2d_NCHWc_int8_o_init"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(128, i1_2_init + i1_3_init + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) + oh = T.axis.spatial(7, i2_1 + i2_2_init + i2_3_init) + ow = T.axis.spatial(7, i3_1 * 7 + i3_2_init * 7 + 
i3_3_init) + oc_block_o = T.axis.spatial(1, 0) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) + for i4_1 in T.vectorized(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_i_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 + for i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 7, 1): + with T.block("conv2d_NCHWc_int8_o_update"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(128, i1_2 + i1_3 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32) + oh = T.axis.spatial(7, i2_1 + i2_2 + i2_3) + ow = T.axis.spatial(7, i3_1 * 7 + i3_2 * 7 + i3_3) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer = T.axis.reduce(32, i7_0 * 8 + i7_1) + ic_f_inner = T.axis.reduce(4, i8_1 + i8_0) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) + A = T.match_buffer(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], [4], dtype="uint8", offset_factor=1) + B = T.match_buffer(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4], [16, 4], dtype="int8", offset_factor=1) + C = T.match_buffer(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], [16], dtype="int32", offset_factor=1) + A_u8x4: T.uint8x4 = A[0:4] + A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") + B_i8x64: T.int8x64 = B[0, 0:64] + B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16") + C[0:16] = C[0:16] + T.call_llvm_pure_intrin(intrin_id, T.uint32(0), T.broadcast(0, 16), T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7): + for ax4_fused in T.vectorized(16): + with T.block("T_cast_8"): + ax0_1 = T.axis.spatial(1, ax0) + ax1_1 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32 + ax1) + ax2_1 = T.axis.spatial(7, i2_1 + ax2) + ax3_1, ax4 = T.axis.remap("SS", [ax3, ax4_fused]) + T.reads(conv2d_NCHWc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4], p2[ax0_1, ax1_1, 0, 0, ax4], p3[ax0_1, ax1_1, 0, 0, ax4], p4[0], p5[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) + T.writes(T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) + T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4] = T.cast(T.max(T.min(T.cast(T.max(T.min(T.cast(T.floor(T.float32(0.95489668846130371) * (T.cast(T.cast(T.max(T.min(T.cast(T.floor(T.cast(conv2d_NCHWc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4] + p2[ax0_1, ax1_1, 0, 0, ax4], "float32") * p3[ax0_1, ax1_1, 0, 0, ax4] + T.float32(65.5), dtype="float32"), "int32"), 255), 0), "uint8"), "float32") - p4[0]) + T.float32(0.5), dtype="float32"), "int32") + T.cast(T.floor(T.float32(0.71245479583740234) * T.cast(p5[ax0_1, ax1_1, ax2_1, ax3_1, ax4], "float32") + T.float32(0.5), dtype="float32"), "int32"), 255), 0), "uint8"), T.uint8(255)), T.uint8(0)), "int32") + + return Conv2dInt8_NCHWc_scheduled + + +@tvm.script.ir_module +class Conv2dWinogradAddRelu: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(6, 6, 64, 64), "float32"], p2: T.Buffer[(1, 1, 1, 64), "float32"], T_relu: T.Buffer[(1, 56, 56, 64), 
"float32"]) -> None: + # function attr dict + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + data_pad = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + input_tile = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + B = T.alloc_buffer([6, 6], dtype="float32") + data_pack = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + A = T.alloc_buffer([6, 4], dtype="float32") + inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32") + conv2d_winograd = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + T_add = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) + T.block_attr({"schedule_rule":"None"}) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") + for i0, i1, i2, i3 in T.grid(6, 6, 196, 64): + with T.block("input_tile"): + eps, nu, p, ci = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci]) + T.writes(input_tile[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"None"}) + input_tile[eps, nu, p, ci] = data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci] + for i0, i1 in T.grid(6, 6): + with T.block("B"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(B[i, j]) + T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + B[i, j] = T.Select(i % 6 == 5 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 5 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 6 == 5, T.float32(1.5), T.Select(i % 6 == 4 and j % 6 == 4, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 3, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 2, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 6 == 5, T.float32(-2), T.Select(i % 6 == 3 and j % 6 == 4, T.float32(-0.5), T.Select(i % 6 == 3 and j % 6 == 3, T.float32(2), T.Select(i % 6 == 3 and j % 6 == 2, T.float32(2.5), T.Select(i % 6 == 3 and j % 6 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 6 == 0, T.float32(1.5), T.Select(i % 6 == 2 and j % 6 == 5, T.float32(-1.5), T.Select(i % 6 == 2 and j % 6 == 4, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 3, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 2, T.float32(0.5), T.Select(i % 6 == 2 and j % 6 == 1, T.float32(-2.5), T.Select(i % 6 == 2 and j % 6 == 0, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 4, T.float32(0.5), T.Select(i % 6 == 1 and j % 6 == 3, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 2, T.float32(-1), T.Select(i % 6 == 1 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 0, T.float32(-1.5), T.Select(i % 6 == 0 and j % 6 == 5, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 0 and j 
% 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(6, 6, 196, 64, 6, 6): + with T.block("data_pack"): + eps, nu, p, ci, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(input_tile[r_a, r_b, p, ci], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + with T.init(): + data_pack[eps, nu, p, ci] = T.float32(0) + data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu] + for i0, i1, i2, i3, i4 in T.grid(6, 6, 196, 64, 64): + with T.block("bgemm"): + eps, nu, p, co, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) + T.reads(data_pack[eps, nu, p, ci], p1[eps, nu, co, ci]) + T.writes(bgemm[eps, nu, p, co]) + T.block_attr({"layout_free_placeholders":[]}) + with T.init(): + bgemm[eps, nu, p, co] = T.float32(0) + bgemm[eps, nu, p, co] = bgemm[eps, nu, p, co] + data_pack[eps, nu, p, ci] * p1[eps, nu, co, ci] + for i0, i1 in T.grid(6, 4): + with T.block("A"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(A[i, j]) + T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + A[i, j] = T.Select(i % 6 == 5 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 5 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 4 == 3, T.float32(-8), T.Select(i % 6 == 4 and j % 4 == 2, T.float32(4), T.Select(i % 6 == 4 and j % 4 == 1, T.float32(-2), T.Select(i % 6 == 4 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 4 == 3, T.float32(0.125), T.Select(i % 6 == 3 and j % 4 == 2, T.float32(0.25), T.Select(i % 6 == 3 and j % 4 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 196, 64, 6, 6): + with T.block("inverse"): + vh, vw, p, co, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(bgemm[r_a, r_b, p, co], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + with T.init(): + inverse[vh, vw, p, co] = T.float32(0) + inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("conv2d_winograd"): + n, h, w, co = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co]) + T.writes(conv2d_winograd[n, h, w, co]) + conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 
196 + h // 4 * 14 + w // 4, co] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[ax0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_relu"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3]) + T.writes(T_relu[ax0, ax1, ax2, ax3]) + T_relu[ax0, ax1, ax2, ax3] = T.max(T_add[ax0, ax1, ax2, ax3], T.float32(0)) + + +@tvm.script.ir_module +class Conv2dWinogradAddResidualRelu: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(6, 6, 64, 64), "float32"], p2: T.Buffer[(1, 1, 1, 64), "float32"], p3: T.Buffer[(1, 56, 56, 64), "float32"], T_relu: T.Buffer[(1, 56, 56, 64), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + data_pad = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + input_tile = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + B = T.alloc_buffer([6, 6], dtype="float32") + data_pack = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + A = T.alloc_buffer([6, 4], dtype="float32") + inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32") + conv2d_winograd = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + T_add = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + T_add_1 = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) + T.block_attr({"schedule_rule":"None"}) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") + for i0, i1, i2, i3 in T.grid(6, 6, 196, 64): + with T.block("input_tile"): + eps, nu, p, ci = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci]) + T.writes(input_tile[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"None"}) + input_tile[eps, nu, p, ci] = data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci] + for i0, i1 in T.grid(6, 6): + with T.block("B"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(B[i, j]) + T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + B[i, j] = T.Select(i % 6 == 5 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 5 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 6 == 5, T.float32(1.5), T.Select(i % 6 == 4 and j % 6 == 4, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 3, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 2, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 6 == 5, T.float32(-2), T.Select(i % 6 == 3 and j % 6 == 4, T.float32(-0.5), T.Select(i % 6 == 3 and j % 6 == 3, T.float32(2), T.Select(i % 6 == 3 and j % 6 == 2, T.float32(2.5), 
T.Select(i % 6 == 3 and j % 6 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 6 == 0, T.float32(1.5), T.Select(i % 6 == 2 and j % 6 == 5, T.float32(-1.5), T.Select(i % 6 == 2 and j % 6 == 4, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 3, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 2, T.float32(0.5), T.Select(i % 6 == 2 and j % 6 == 1, T.float32(-2.5), T.Select(i % 6 == 2 and j % 6 == 0, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 4, T.float32(0.5), T.Select(i % 6 == 1 and j % 6 == 3, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 2, T.float32(-1), T.Select(i % 6 == 1 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 0, T.float32(-1.5), T.Select(i % 6 == 0 and j % 6 == 5, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(6, 6, 196, 64, 6, 6): + with T.block("data_pack"): + eps, nu, p, ci, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(input_tile[r_a, r_b, p, ci], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + with T.init(): + data_pack[eps, nu, p, ci] = T.float32(0) + data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu] + for i0, i1, i2, i3, i4 in T.grid(6, 6, 196, 64, 64): + with T.block("bgemm"): + eps, nu, p, co, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) + T.reads(data_pack[eps, nu, p, ci], p1[eps, nu, co, ci]) + T.writes(bgemm[eps, nu, p, co]) + T.block_attr({"layout_free_placeholders":[]}) + with T.init(): + bgemm[eps, nu, p, co] = T.float32(0) + bgemm[eps, nu, p, co] = bgemm[eps, nu, p, co] + data_pack[eps, nu, p, ci] * p1[eps, nu, co, ci] + for i0, i1 in T.grid(6, 4): + with T.block("A"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(A[i, j]) + T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + A[i, j] = T.Select(i % 6 == 5 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 5 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 4 == 3, T.float32(-8), T.Select(i % 6 == 4 and j % 4 == 2, T.float32(4), T.Select(i % 6 == 4 and j % 4 == 1, T.float32(-2), T.Select(i % 6 == 4 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 4 == 3, T.float32(0.125), T.Select(i % 6 == 3 and j % 4 == 2, T.float32(0.25), T.Select(i % 6 == 3 and j % 4 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 
6 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 196, 64, 6, 6): + with T.block("inverse"): + vh, vw, p, co, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(bgemm[r_a, r_b, p, co], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + with T.init(): + inverse[vh, vw, p, co] = T.float32(0) + inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("conv2d_winograd"): + n, h, w, co = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co]) + T.writes(conv2d_winograd[n, h, w, co]) + conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[ax0, 0, 0, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3], p3[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_add[ax0, ax1, ax2, ax3] + p3[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): + with T.block("T_relu"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[ax0, ax1, ax2, ax3]) + T.writes(T_relu[ax0, ax1, ax2, ax3]) + T_relu[ax0, ax1, ax2, ax3] = T.max(T_add_1[ax0, ax1, ax2, ax3], T.float32(0)) + + +@tvm.script.ir_module +class Conv2dWinogradAddResidualRelu_scheduled: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(6, 6, 64, 64), "float32"], p2: T.Buffer[(1, 1, 1, 64), "float32"], p3: T.Buffer[(1, 56, 56, 64), "float32"], T_relu: T.Buffer[(1, 56, 56, 64), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + input_tile_local = T.alloc_buffer([6, 6, 196, 64], dtype="float32", scope="local") + data_pack = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32") + inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32") + bgemm_local = T.alloc_buffer([6, 6, 196, 64], dtype="float32", scope="local") + data_pack_shared = T.alloc_buffer([6, 6, 196, 64], dtype="float32", scope="shared") + p1_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared") + for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(98, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): + with T.block("input_tile"): + eps, nu = T.axis.remap("SS", [ax0, ax1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8 + ax2) + ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 
128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8 + ax3) + T.reads(p0[p // 196, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1, ci]) + T.writes(input_tile_local[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"None"}) + input_tile_local[eps, nu, p, ci] = T.if_then_else(1 <= p % 196 // 14 * 4 + eps and p % 196 // 14 * 4 + eps < 57 and 1 <= p % 14 * 4 + nu and p % 14 * 4 + nu < 57, p0[p // 196, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1, ci], T.float32(0), dtype="float32") + for i0 in T.unroll(6): + for i1 in T.unroll(6): + with T.block("data_pack_init"): + eps, nu = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8) + ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) + T.reads() + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + data_pack[eps, nu, p, ci] = T.float32(0) + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("data_pack_update"): + eps, nu = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8) + ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(data_pack[eps, nu, p, ci], input_tile_local[r_a, r_b, p, ci]) + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile_local[r_a, r_b, p, ci] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), 
T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_b % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_b % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_b % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_b % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_b % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_b % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_b % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_b % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_b % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_b % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_b % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_b % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_b % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_b % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(168, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(4, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(48, thread="threadIdx.x"): + for i0_3_init, i1_3_init, i2_3_init, i3_3_init, i0_4_init, i1_4_init, i2_4_init, i3_4_init in T.grid(1, 1, 14, 1, 1, 1, 1, 1): + with T.block("bgemm_init"): + eps = T.axis.spatial(6, i0_4_init + i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3_init) + nu = T.axis.spatial(6, i1_4_init + i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3_init) + p = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + i2_3_init + i2_4_init) + co = T.axis.spatial(64, i3_4_init + i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + 
i3_3_init) + T.reads() + T.writes(bgemm_local[eps, nu, p, co]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + bgemm_local[eps, nu, p, co] = T.float32(0) + for i4_0 in T.serial(2): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(28): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(48, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(4): + with T.block("data_pack_shared"): + v0 = T.axis.spatial(6, (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) // 896) + v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28) + v2 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 896 // 32) + v3 = T.axis.spatial(64, i4_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 32) + T.reads(data_pack[v0, v1, v2, v3]) + T.writes(data_pack_shared[v0, v1, v2, v3]) + data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(48, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(4): + with T.block("p1_shared"): + v0 = T.axis.spatial(6, (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) // 512) + v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28) + v2 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 512 // 32) + v3 = T.axis.spatial(64, i4_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 32) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_shared[v0, v1, v2, v3]) + p1_shared[v0, v1, v2, v3] = p1[v0, v1, v2, v3] + for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(2, 1, 1, 14, 1, 16, 1, 1, 1, 1): + with T.block("bgemm_update"): + eps = T.axis.spatial(6, i0_4 + i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3) + nu = T.axis.spatial(6, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3) + p = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + i2_3 + i2_4) + co = T.axis.spatial(64, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i3_3) + ci = T.axis.reduce(64, i4_0 * 32 + i4_1 * 16 + i4_2) + T.reads(bgemm_local[eps, nu, p, co], data_pack_shared[eps, nu, p, ci], p1_shared[eps, nu, co, ci]) + T.writes(bgemm_local[eps, nu, p, co]) + T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + bgemm_local[eps, nu, p, co] = bgemm_local[eps, nu, p, co] + data_pack_shared[eps, nu, p, ci] * p1_shared[eps, nu, co, ci] + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1): + with T.block("bgemm_local"): + v0 = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + ax0) + v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28 + ax1) + v2 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + ax2) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + ax3) + T.reads(bgemm_local[v0, v1, v2, v3]) + T.writes(bgemm[v0, v1, v2, v3]) + 
bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] + for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(25, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(512, thread="threadIdx.x"): + for i0 in T.unroll(4): + for i1 in T.unroll(4): + with T.block("inverse_init"): + T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1 < 12544) + vh, vw = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) // 448 * 7 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 224 // 32) + co = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 448 // 224 * 32 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 32) + T.reads() + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + inverse[vh, vw, p, co] = T.float32(0) + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("inverse_update"): + T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1 < 12544) + vh, vw = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) // 448 * 7 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 224 // 32) + co = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 448 // 224 * 32 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 32) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(inverse[vh, vw, p, co], bgemm[r_a, r_b, p, co]) + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * T.Select(r_a % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 0, 
T.float32(0), T.Select(r_b % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_b % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_b % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_b % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_b % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_b % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0_i1_i2_i3_fused_0 in T.thread_binding(1568, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): + for i0_i1_i2_i3_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("conv2d_winograd"): + n = T.axis.spatial(1, 0) + h = T.axis.spatial(56, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) // 3584) + w = T.axis.spatial(56, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 3584 // 64) + co = T.axis.spatial(64, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 64) + T.reads(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co], p2[n, 0, 0, co], p3[n, h, w, co]) + T.writes(T_relu[n, h, w, co]) + T_relu[n, h, w, co] = T.max(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co] + p2[n, 0, 0, co] + p3[n, h, w, co], T.float32(0)) + + +# fmt: on +def verify(anchor_mod, anchor_trace_fun, target_mod, target, ref): + anchor_sch = Schedule(anchor_mod) + anchor_trace_fun(anchor_sch) + anchor_trace = anchor_sch.trace + + sch = Schedule(target_mod) + + ms.trace_apply.schedule_using_anchor_trace(sch, anchor_trace, Target(target)) + + tvm.ir.assert_structural_equal(ref, sch.mod) + + +def test_dense_add_cpu(): + def apply_anchor_trace(sch: Schedule) -> None: + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, n=4, max_innermost_factor=64, decision=[2, 8, 4, 2] + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, n=4, max_innermost_factor=64, decision=[2, 1, 1, 64] + ) + l17, l18, l19, l20 = sch.split( + loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True + ) + v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[128, 1]) + l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + sch.reverse_compute_at(block=b25, loop=l17, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=160) + sch.annotate(block_or_loop=b1, 
ann_key="meta_schedule.vectorize", ann_val=64) + v26 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0 + ) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26) + sch.enter_postproc() + b27 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize") + sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit") + b28, b29 = sch.get_child_blocks(b27) + l30, l31, l32, l33, l34, l35, l36, l37, l38, l39 = sch.get_loops(block=b28) + l40 = sch.fuse(l30, l31, preserve_unit_iters=True) + sch.parallel(loop=l40) + l41 = sch.fuse(l39, preserve_unit_iters=True) + sch.vectorize(loop=l41) + l42, l43, l44 = sch.get_loops(block=b29) + l45 = sch.fuse(l42, preserve_unit_iters=True) + sch.parallel(loop=l45) + l46 = sch.fuse(l44, preserve_unit_iters=True) + sch.vectorize(loop=l46) + b47 = sch.get_block(name="T_matmul_NT", func_name="main") + l48, l49, l50, l51, l52, l53, l54, l55, l56 = sch.get_loops(block=b47) + b57 = sch.decompose_reduction(block=b47, loop=l51) + b58 = sch.get_block(name="T_matmul_NT_update", func_name="main") + b59 = sch.cache_read(block=b58, read_buffer_index=2, storage_scope="global") + sch.transform_layout( + block=b58, + buffer=("read", 2), + index_map=tvm.tir.IndexMap.from_func( + lambda i0, i1: ( + floordiv(i0, 64), + i1, + floormod(i0, 64), + ), + inverse_index_map=lambda i0, i1, i2: ( + ((i0 * 64) + i2), + i1, + ), + ), + pad_value=None, + ) + sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) + + verify(Dense, apply_anchor_trace, DenseAdd, "llvm", DenseAdd_scheduled_cpu) + + +def test_dense_add_cpu_no_write_cache(): + def apply_trace(sch): + b0 = sch.get_block(name="T_matmul_NT", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, n=4, max_innermost_factor=64, decision=[4, 4, 4, 2] + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, n=4, max_innermost_factor=64, decision=[1, 1, 4, 32] + ) + l17, l18, l19, l20 = sch.split( + loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True + ) + v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[8, 16]) + l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=160) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64) + v25 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25) + sch.enter_postproc() + b26 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.parallel") + sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.vectorize") + sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.unroll_explicit") + (b27,) = sch.get_child_blocks(b26) + l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) + l38 = sch.fuse(l28, l29, l30, l31, 
preserve_unit_iters=True) + sch.parallel(loop=l38) + l39 = sch.fuse(l37, preserve_unit_iters=True) + sch.vectorize(loop=l39) + sch.annotate(block_or_loop=l38, ann_key="pragma_auto_unroll_max_step", ann_val=16) + sch.annotate(block_or_loop=l38, ann_key="pragma_unroll_explicit", ann_val=1) + b40 = sch.get_block(name="T_matmul_NT", func_name="main") + l41, l42, l43, l44, l45, l46, l47 = sch.get_loops(block=b40) + b48 = sch.decompose_reduction(block=b40, loop=l42) + b49 = sch.get_block(name="T_matmul_NT_update", func_name="main") + b50 = sch.cache_read(block=b49, read_buffer_index=2, storage_scope="global") + sch.transform_layout( + block=b49, + buffer=("read", 2), + index_map=tvm.tir.IndexMap.from_func( + lambda i0, i1: ( + floordiv(i1, 16), + floordiv(i0, 32), + floormod(i1, 16), + floormod(i0, 32), + ), + inverse_index_map=lambda i0, i1, i2, i3: ( + ((i1 * 32) + i3), + ((i0 * 16) + i2), + ), + ), + pad_value=None, + ) + sch.annotate(block_or_loop=b50, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) + + verify(Dense, apply_trace, DenseAdd, "llvm", DenseAdd_cpu_no_write_cache) + + def test_dense_add_gpu(): def apply_anchor_trace(sch: Schedule) -> None: b0 = sch.get_block(name="T_matmul_NT", func_name="main") @@ -501,7 +1774,971 @@ def apply_anchor_trace(sch: Schedule) -> None: l105, l106, l107, l108, l109, l110, l111, l112, l113, l114 = sch.get_loops(block=b104) b115 = sch.decompose_reduction(block=b104, loop=l108) - verify(Dense, apply_anchor_trace, DenseAdd, "metal", DenseAdd_scheduled_gpu) + verify(Dense, apply_anchor_trace, DenseAdd, "cuda", DenseAdd_scheduled_gpu) + + +def test_conv2d_int8_tensorcore(): + def apply_trace(sch): + b0 = sch.get_block(name="pad_temp", func_name="main") + b1 = sch.get_block(name="conv2d_nhwc", func_name="main") + b2 = sch.get_block(name="T_subtract", func_name="main") + b3 = sch.get_block(name="T_add", func_name="main") + b4 = sch.get_block(name="T_cast", func_name="main") + b5 = sch.get_block(name="T_multiply", func_name="main") + b6 = sch.get_block(name="T_add_1", func_name="main") + b7 = sch.get_block(name="T_right_shift", func_name="main") + b8 = sch.get_block(name="T_cast_1", func_name="main") + b9 = sch.get_block(name="T_add_2", func_name="main") + b10 = sch.get_block(name="compute", func_name="main") + b11 = sch.get_block(name="T_cast_2", func_name="main") + b12 = sch.get_block(name="T_cast_3", func_name="main") + b13 = sch.get_block(name="T_subtract_1", func_name="main") + b14 = sch.get_block(name="compute_1", func_name="main") + b15 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + b16 = sch.reindex(block=b1, buffer=("write", 0)) + b17 = sch.reindex(block=b1, buffer=("read", 0)) + b18 = sch.reindex(block=b1, buffer=("read", 1)) + sch.transform_layout( + block=b1, + buffer=("read", 0), + index_map=lambda nn, yy, xx, rc: ( + (((nn * 3136) + (yy * 56)) + xx), + rc, + ), + pad_value=None, + ) + sch.transform_layout( + block=b1, + buffer=("read", 1), + index_map=lambda ff, ry, rx, rc: ( + ry, + rx, + ff, + rc, + ), + pad_value=None, + ) + sch.transform_layout( + block=b1, + buffer=("write", 0), + index_map=lambda nn, yy, xx, ff: ( + (((nn * 3136) + (yy * 56)) + xx), + ff, + ), + pad_value=None, + ) + sch.transform_block_layout( + block=b16, + index_map=lambda nn, yy, xx, ff: ( + (((nn * 3136) + (yy * 56)) + xx), + ff, + ), + ) + sch.transform_block_layout( + block=b17, + index_map=lambda nn, yy, xx, rc: ( + (((nn * 3136) + (yy * 56)) + xx), + rc, + ), 
+ ) + sch.transform_block_layout( + block=b18, + index_map=lambda ff, ry, rx, rc: ( + ry, + rx, + ff, + rc, + ), + ) + sch.transform_block_layout( + block=b1, + index_map=lambda nn, yy, xx, ff, ry, rx, rc: ( + ry, + rx, + (((nn * 3136) + (yy * 56)) + xx), + ff, + rc, + ), + ) + l19, l20, l21, l22, l23 = sch.get_loops(block=b1) + l24, l25 = sch.split(loop=l23, factors=[None, 16], preserve_unit_iters=True) + l26, l27 = sch.split(loop=l22, factors=[None, 16], preserve_unit_iters=True) + l28, l29 = sch.split(loop=l21, factors=[None, 16], preserve_unit_iters=True) + l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b1) + sch.reorder(l34, l36, l29, l27, l25) + b38 = sch.blockize(loop=l29) + sch.annotate( + block_or_loop=b38, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_sync_16x16x16_s8s8s32_trans", + ) + sch.annotate( + block_or_loop=b38, + ann_key="meta_schedule.auto_tensorize_init", + ann_val="wmma_fill_16x16x16_s32", + ) + sch.annotate(block_or_loop=b38, ann_key="warp_execution", ann_val=1) + l39, l40, l41, l42, l43 = sch.get_loops(block=b38) + v44, v45, v46 = sch.sample_perfect_tile( + loop=l39, n=3, max_innermost_factor=4, decision=[1, 1, 1] + ) + l47, l48, l49 = sch.split(loop=l39, factors=[v44, v45, v46], preserve_unit_iters=True) + v50, v51, v52 = sch.sample_perfect_tile( + loop=l40, n=3, max_innermost_factor=4, decision=[1, 1, 1] + ) + l53, l54, l55 = sch.split(loop=l40, factors=[v50, v51, v52], preserve_unit_iters=True) + v56, v57, v58, v59, v60 = sch.sample_perfect_tile( + loop=l41, n=5, max_innermost_factor=4, decision=[392, 1, 8, 1, 1] + ) + l61, l62, l63, l64, l65 = sch.split( + loop=l41, factors=[v56, v57, v58, v59, v60], preserve_unit_iters=True + ) + v66, v67, v68, v69, v70 = sch.sample_perfect_tile( + loop=l42, n=5, max_innermost_factor=4, decision=[8, 1, 2, 1, 1] + ) + l71, l72, l73, l74, l75 = sch.split( + loop=l42, factors=[v66, v67, v68, v69, v70], preserve_unit_iters=True + ) + v76, v77, v78 = sch.sample_perfect_tile( + loop=l43, n=3, max_innermost_factor=4, decision=[2, 1, 2] + ) + l79, l80, l81 = sch.split(loop=l43, factors=[v76, v77, v78], preserve_unit_iters=True) + sch.reorder( + l61, + l71, + l62, + l72, + l63, + l73, + l47, + l53, + l79, + l48, + l54, + l80, + l64, + l74, + l49, + l55, + l81, + l65, + l75, + ) + l82 = sch.fuse(l61, l71, preserve_unit_iters=True) + sch.bind(loop=l82, thread_axis="blockIdx.x") + l83 = sch.fuse(l62, l72, preserve_unit_iters=True) + sch.bind(loop=l83, thread_axis="vthread.x") + l84 = sch.fuse(l63, l73, preserve_unit_iters=True) + sch.bind(loop=l84, thread_axis="threadIdx.x") + sch.annotate( + block_or_loop=b38, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32 + ) + sch.annotate( + block_or_loop=b38, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024 + ) + b85 = sch.cache_write(block=b38, write_buffer_index=0, storage_scope="shared") + sch.reverse_compute_at(block=b85, loop=l83, preserve_unit_loops=True, index=-1) + b86 = sch.cache_write(block=b38, write_buffer_index=0, storage_scope="wmma.accumulator") + sch.reverse_compute_at(block=b86, loop=l84, preserve_unit_loops=True, index=-1) + v87 = sch.sample_categorical( + candidates=[1, 2, 3, 4, 8, 16], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=0, + ) + sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v87) + sch.reverse_compute_inline(block=b16) + l88, l89, l90, l91, l92 = 
sch.get_loops(block=b86) + l93, l94 = sch.split(loop=l92, factors=[None, 16], preserve_unit_iters=True) + l95, l96 = sch.split(loop=l91, factors=[None, 16], preserve_unit_iters=True) + l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b86) + sch.reorder(l102, l96, l94) + b104 = sch.blockize(loop=l96) + sch.annotate( + block_or_loop=b104, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_store_16x16x16_s32_shared", + ) + b105 = sch.cache_read( + block=b38, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b38] + ) + sch.compute_at(block=b105, loop=l79, preserve_unit_loops=True, index=-1) + l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b105) + l114 = sch.fuse(l112, l113, preserve_unit_iters=True) + v115 = sch.sample_categorical( + candidates=[1, 2, 3, 4, 8, 16], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=5, + ) + sch.annotate(block_or_loop=b105, ann_key="meta_schedule.cooperative_fetch", ann_val=v115) + b116 = sch.cache_read( + block=b38, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b38] + ) + sch.compute_at(block=b116, loop=l79, preserve_unit_loops=True, index=-1) + l117, l118, l119, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b116) + l127 = sch.fuse(l123, l124, l125, l126, preserve_unit_iters=True) + v128 = sch.sample_categorical( + candidates=[1, 2, 3, 4, 8, 16], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=4, + ) + sch.annotate(block_or_loop=b116, ann_key="meta_schedule.cooperative_fetch", ann_val=v128) + b129 = sch.cache_read(block=b38, read_buffer_index=0, storage_scope="wmma.matrix_a") + sch.compute_at(block=b129, loop=l80, preserve_unit_loops=True, index=-1) + l130, l131, l132, l133, l134, l135, l136, l137, l138, l139, l140 = sch.get_loops(block=b129) + l141, l142 = sch.split(loop=l140, factors=[None, 16], preserve_unit_iters=True) + l143, l144 = sch.split(loop=l139, factors=[None, 16], preserve_unit_iters=True) + ( + l145, + l146, + l147, + l148, + l149, + l150, + l151, + l152, + l153, + l154, + l155, + l156, + l157, + ) = sch.get_loops(block=b129) + sch.reorder(l156, l144, l142) + b158 = sch.blockize(loop=l144) + sch.annotate( + block_or_loop=b158, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_load_16x16x16_s8_a", + ) + b159 = sch.cache_read(block=b38, read_buffer_index=1, storage_scope="wmma.matrix_b") + sch.compute_at(block=b159, loop=l80, preserve_unit_loops=True, index=-1) + ( + l160, + l161, + l162, + l163, + l164, + l165, + l166, + l167, + l168, + l169, + l170, + l171, + l172, + ) = sch.get_loops(block=b159) + l173, l174 = sch.split(loop=l172, factors=[None, 16], preserve_unit_iters=True) + l175, l176 = sch.split(loop=l171, factors=[None, 16], preserve_unit_iters=True) + ( + l177, + l178, + l179, + l180, + l181, + l182, + l183, + l184, + l185, + l186, + l187, + l188, + l189, + l190, + l191, + ) = sch.get_loops(block=b159) + sch.reorder(l190, l176, l174) + b192 = sch.blockize(loop=l176) + sch.annotate( + block_or_loop=b192, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_load_16x16x16_s8_b_trans", + ) + sch.compute_inline(block=b17) + sch.compute_inline(block=b18) + sch.storage_align(block=b105, buffer_index=0, axis=-2, factor=32, offset=16) + sch.storage_align(block=b116, buffer_index=0, axis=-2, factor=32, offset=16) + 
sch.reverse_compute_inline(block=b14) + sch.reverse_compute_inline(block=b13) + sch.reverse_compute_inline(block=b12) + sch.reverse_compute_inline(block=b11) + sch.reverse_compute_inline(block=b10) + sch.reverse_compute_inline(block=b9) + sch.reverse_compute_inline(block=b8) + sch.reverse_compute_inline(block=b7) + sch.reverse_compute_inline(block=b6) + sch.reverse_compute_inline(block=b5) + sch.reverse_compute_inline(block=b4) + sch.reverse_compute_inline(block=b3) + sch.reverse_compute_inline(block=b2) + sch.compute_inline(block=b0) + v193 = sch.sample_categorical( + candidates=[0, 16, 64, 512, 1024], + probs=[ + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + ], + decision=3, + ) + sch.annotate(block_or_loop=b15, ann_key="meta_schedule.unroll_explicit", ann_val=v193) + sch.enter_postproc() + sch.unannotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch") + l194, l195, l196, l197 = sch.get_loops(block=b85) + l198, l199 = sch.split(loop=l197, factors=[None, 16], preserve_unit_iters=True) + sch.bind(loop=l199, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b105, ann_key="meta_schedule.cooperative_fetch") + l200, l201, l202, l203, l204, l205, l206 = sch.get_loops(block=b105) + l207, l208, l209 = sch.split(loop=l206, factors=[None, 16, 16], preserve_unit_iters=True) + sch.vectorize(loop=l209) + sch.bind(loop=l208, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b116, ann_key="meta_schedule.cooperative_fetch") + l210, l211, l212, l213, l214, l215, l216 = sch.get_loops(block=b116) + l217, l218, l219 = sch.split(loop=l216, factors=[None, 16, 8], preserve_unit_iters=True) + sch.vectorize(loop=l219) + sch.bind(loop=l218, thread_axis="threadIdx.x") + b220 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b220, ann_key="meta_schedule.unroll_explicit") + b221, b222, b223, b224, b225, b226, b227 = sch.get_child_blocks(b220) + l228, l229, l230, l231, l232, l233, l234, l235, l236 = sch.get_loops(block=b221) + sch.annotate(block_or_loop=l228, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l228, ann_key="pragma_unroll_explicit", ann_val=1) + l237, l238, l239, l240, l241, l242, l243, l244, l245 = sch.get_loops(block=b222) + sch.annotate(block_or_loop=l237, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l237, ann_key="pragma_unroll_explicit", ann_val=1) + l246, l247, l248, l249, l250, l251, l252, l253, l254, l255, l256 = sch.get_loops(block=b223) + sch.annotate(block_or_loop=l246, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l246, ann_key="pragma_unroll_explicit", ann_val=1) + ( + l257, + l258, + l259, + l260, + l261, + l262, + l263, + l264, + l265, + l266, + l267, + l268, + l269, + ) = sch.get_loops(block=b224) + sch.annotate(block_or_loop=l257, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l257, ann_key="pragma_unroll_explicit", ann_val=1) + ( + l270, + l271, + l272, + l273, + l274, + l275, + l276, + l277, + l278, + l279, + l280, + l281, + l282, + l283, + l284, + l285, + ) = sch.get_loops(block=b225) + sch.annotate(block_or_loop=l270, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l270, ann_key="pragma_unroll_explicit", ann_val=1) + l286, l287, l288, l289, l290 = sch.get_loops(block=b226) + sch.annotate(block_or_loop=l286, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l286, 
ann_key="pragma_unroll_explicit", ann_val=1) + l291, l292, l293, l294, l295 = sch.get_loops(block=b227) + sch.annotate(block_or_loop=l291, ann_key="pragma_auto_unroll_max_step", ann_val=512) + sch.annotate(block_or_loop=l291, ann_key="pragma_unroll_explicit", ann_val=1) + b296 = sch.get_block(name="conv2d_nhwc_o", func_name="main") + ( + l297, + l298, + l299, + l300, + l301, + l302, + l303, + l304, + l305, + l306, + l307, + l308, + l309, + l310, + l311, + l312, + ) = sch.get_loops(block=b296) + b313 = sch.decompose_reduction(block=b296, loop=l302) + sch.unannotate(block_or_loop=b313, ann_key="meta_schedule.auto_tensorize") + sch.annotate( + block_or_loop=b313, + ann_key="meta_schedule.auto_tensorize", + ann_val="wmma_fill_16x16x16_s32", + ) + sch.unannotate(block_or_loop=b296, ann_key="meta_schedule.auto_tensorize_init") + sch.unannotate(block_or_loop=b313, ann_key="meta_schedule.auto_tensorize_init") + b314 = sch.get_block(name="conv2d_nhwc_o_init", func_name="main") + sch.unannotate(block_or_loop=b314, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b314, tensor_intrin="wmma_fill_16x16x16_s32") + b315 = sch.get_block(name="pad_temp_reindex_shared_wmma.matrix_a_o", func_name="main") + sch.unannotate(block_or_loop=b315, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b315, tensor_intrin="wmma_load_16x16x16_s8_a") + b316 = sch.get_block(name="p1_reindex_shared_wmma.matrix_b_o", func_name="main") + sch.unannotate(block_or_loop=b316, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b316, tensor_intrin="wmma_load_16x16x16_s8_b_trans") + b317 = sch.get_block(name="conv2d_nhwc_o_update", func_name="main") + sch.unannotate(block_or_loop=b317, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b317, tensor_intrin="wmma_sync_16x16x16_s8s8s32_trans") + b318 = sch.get_block(name="conv2d_nhwc_reindex_shared_wmma.accumulator_o", func_name="main") + sch.unannotate(block_or_loop=b318, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b318, tensor_intrin="wmma_store_16x16x16_s32_shared") + + verify(Conv2dInt8, apply_trace, Conv2dInt8_target, "cuda", Conv2dInt8_tensorcore_scheduled) + + +def test_conv2d_int8_vnni(): + def apply_trace(sch): + b0 = sch.get_block(name="compile_engine_const", func_name="main") + b1 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") + b2 = sch.get_block(name="T_add", func_name="main") + b3 = sch.get_block(name="T_cast", func_name="main") + b4 = sch.get_block(name="T_multiply", func_name="main") + b5 = sch.get_block(name="compile_engine_const_1", func_name="main") + b6 = sch.get_block(name="T_add_1", func_name="main") + b7 = sch.get_block(name="T_floor", func_name="main") + b8 = sch.get_block(name="T_cast_1", func_name="main") + b9 = sch.get_block(name="compute", func_name="main") + b10 = sch.get_block(name="T_cast_2", func_name="main") + b11 = sch.get_block(name="T_cast_3", func_name="main") + b12 = sch.get_block(name="T_subtract", func_name="main") + b13 = sch.get_block(name="T_multiply_1", func_name="main") + b14 = sch.get_block(name="compile_engine_const_2", func_name="main") + b15 = sch.get_block(name="T_add_2", func_name="main") + b16 = sch.get_block(name="T_floor_1", func_name="main") + b17 = sch.get_block(name="T_cast_4", func_name="main") + b18 = sch.get_block(name="T_add_3", func_name="main") + b19 = sch.get_block(name="compute_1", func_name="main") + b20 = sch.get_block(name="T_cast_5", func_name="main") + b21 = sch.get_block(name="root", func_name="main") 
+ sch.compute_inline(block=b20) + sch.compute_inline(block=b19) + sch.compute_inline(block=b18) + sch.compute_inline(block=b17) + sch.compute_inline(block=b16) + sch.compute_inline(block=b15) + sch.compute_inline(block=b14) + sch.compute_inline(block=b13) + sch.compute_inline(block=b12) + sch.compute_inline(block=b11) + sch.compute_inline(block=b10) + sch.compute_inline(block=b9) + sch.compute_inline(block=b8) + sch.compute_inline(block=b7) + sch.compute_inline(block=b6) + sch.compute_inline(block=b5) + sch.compute_inline(block=b4) + sch.compute_inline(block=b3) + sch.compute_inline(block=b2) + sch.compute_inline(block=b0) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l22, l23, l24, l25, l26, l27, l28, l29, l30, l31 = sch.get_loops(block=b1) + l32, l33 = sch.split(loop=l31, factors=[None, 4], preserve_unit_iters=True) + l34, l35 = sch.split(loop=l26, factors=[None, 16], preserve_unit_iters=True) + l36, l37, l38, l39, l40, l41, l42, l43, l44, l45, l46, l47 = sch.get_loops(block=b1) + sch.reorder(l42, l43, l44, l45, l46, l35, l33) + b48 = sch.blockize(loop=l35) + sch.annotate( + block_or_loop=b48, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni" + ) + l49, l50, l51, l52, l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b48) + v59, v60, v61, v62 = sch.sample_perfect_tile( + loop=l49, n=4, max_innermost_factor=64, decision=[1, 1, 1, 1] + ) + l63, l64, l65, l66 = sch.split( + loop=l49, factors=[v59, v60, v61, v62], preserve_unit_iters=True + ) + v67, v68, v69, v70 = sch.sample_perfect_tile( + loop=l50, n=4, max_innermost_factor=64, decision=[4, 32, 1, 1] + ) + l71, l72, l73, l74 = sch.split( + loop=l50, factors=[v67, v68, v69, v70], preserve_unit_iters=True + ) + v75, v76, v77, v78 = sch.sample_perfect_tile( + loop=l51, n=4, max_innermost_factor=64, decision=[1, 7, 1, 1] + ) + l79, l80, l81, l82 = sch.split( + loop=l51, factors=[v75, v76, v77, v78], preserve_unit_iters=True + ) + v83, v84, v85, v86 = sch.sample_perfect_tile( + loop=l52, n=4, max_innermost_factor=64, decision=[1, 1, 1, 7] + ) + l87, l88, l89, l90 = sch.split( + loop=l52, factors=[v83, v84, v85, v86], preserve_unit_iters=True + ) + v91, v92, v93, v94 = sch.sample_perfect_tile( + loop=l53, n=4, max_innermost_factor=64, decision=[1, 1, 1, 1] + ) + l95, l96, l97, l98 = sch.split( + loop=l53, factors=[v91, v92, v93, v94], preserve_unit_iters=True + ) + v99, v100 = sch.sample_perfect_tile(loop=l54, n=2, max_innermost_factor=64, decision=[1, 1]) + l101, l102 = sch.split(loop=l54, factors=[v99, v100], preserve_unit_iters=True) + v103, v104 = sch.sample_perfect_tile( + loop=l55, n=2, max_innermost_factor=64, decision=[1, 1] + ) + l105, l106 = sch.split(loop=l55, factors=[v103, v104], preserve_unit_iters=True) + v107, v108 = sch.sample_perfect_tile( + loop=l56, n=2, max_innermost_factor=64, decision=[4, 8] + ) + l109, l110 = sch.split(loop=l56, factors=[v107, v108], preserve_unit_iters=True) + v111, v112 = sch.sample_perfect_tile( + loop=l57, n=2, max_innermost_factor=64, decision=[4, 1] + ) + l113, l114 = sch.split(loop=l57, factors=[v111, v112], preserve_unit_iters=True) + v115, v116 = sch.sample_perfect_tile( + loop=l58, n=2, max_innermost_factor=64, decision=[1, 1] + ) + l117, l118 = sch.split(loop=l58, factors=[v115, v116], preserve_unit_iters=True) + sch.reorder( + l63, + l71, + l79, + l87, + l95, + l64, + l72, + l80, + l88, + l96, + l101, + l105, + l109, + l113, + l117, + l65, + l73, + l81, + l89, + l97, + l102, + l106, + l110, + l114, + l118, + l66, + l74, + l82, + 
l90, + l98, + ) + (b119,) = sch.get_consumers(block=b48) + sch.reverse_compute_at(block=b119, loop=l96, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b21, ann_key="meta_schedule.parallel", ann_val=96) + sch.annotate(block_or_loop=b21, ann_key="meta_schedule.vectorize", ann_val=64) + v120 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=2 + ) + sch.annotate(block_or_loop=b21, ann_key="meta_schedule.unroll_explicit", ann_val=v120) + sch.enter_postproc() + b121 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.parallel") + sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.vectorize") + sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.unroll_explicit") + b122, b123 = sch.get_child_blocks(b121) + ( + l124, + l125, + l126, + l127, + l128, + l129, + l130, + l131, + l132, + l133, + l134, + l135, + l136, + l137, + l138, + l139, + l140, + l141, + l142, + l143, + l144, + l145, + l146, + l147, + l148, + l149, + l150, + l151, + l152, + l153, + ) = sch.get_loops(block=b122) + l154 = sch.fuse(l124, l125, l126, l127, l128, l129, l130, preserve_unit_iters=True) + sch.parallel(loop=l154) + sch.annotate(block_or_loop=l154, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l154, ann_key="pragma_unroll_explicit", ann_val=1) + l155, l156, l157, l158, l159, l160, l161, l162, l163 = sch.get_loops(block=b123) + l164 = sch.fuse(l163, preserve_unit_iters=True) + sch.vectorize(loop=l164) + sch.annotate(block_or_loop=l155, ann_key="pragma_auto_unroll_max_step", ann_val=64) + sch.annotate(block_or_loop=l155, ann_key="pragma_unroll_explicit", ann_val=1) + b165 = sch.get_block(name="conv2d_NCHWc_int8_o", func_name="main") + ( + l166, + l167, + l168, + l169, + l170, + l171, + l172, + l173, + l174, + l175, + l176, + l177, + l178, + l179, + l180, + l181, + l182, + l183, + l184, + l185, + l186, + l187, + l188, + l189, + ) = sch.get_loops(block=b165) + b190 = sch.decompose_reduction(block=b165, loop=l172) + sch.unannotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize") + sch.annotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize", ann_val="") + b191 = sch.get_block(name="conv2d_NCHWc_int8_o_init", func_name="main") + sch.unannotate(block_or_loop=b191, ann_key="meta_schedule.auto_tensorize") + (b192,) = sch.get_child_blocks(b191) + (l193,) = sch.get_loops(block=b192) + sch.vectorize(loop=l193) + b194 = sch.get_block(name="conv2d_NCHWc_int8_o_update", func_name="main") + sch.unannotate(block_or_loop=b194, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b194, tensor_intrin="dot_16x4_vnni") + + vnni_id = llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512") + verify( + Conv2dInt8_NCHWc, + apply_trace, + Conv2dInt8_NCHWc_target, + "llvm -mcpu=cascadelake", + get_conv2d_vnni_mod(vnni_id), + ) + + +def test_winograd_gpu(): + def apply_trace(sch): + b0 = sch.get_block(name="B", func_name="main") + b1 = sch.get_block(name="data_pack", func_name="main") + b2 = sch.get_block(name="bgemm", func_name="main") + b3 = sch.get_block(name="A", func_name="main") + b4 = sch.get_block(name="inverse", func_name="main") + b5 = sch.get_block(name="conv2d_winograd", func_name="main") + b6 = sch.get_block(name="T_add", func_name="main") + b7 = sch.get_block(name="T_relu", func_name="main") + b8 = sch.get_block(name="root", func_name="main") + sch.compute_inline(block=b0) + (b9,) = sch.get_producers(block=b1) + (b10,) = 
sch.get_producers(block=b9) + l11, l12, l13, l14, l15, l16 = sch.get_loops(block=b1) + v17, v18 = sch.sample_perfect_tile( + loop=l13, n=2, max_innermost_factor=64, decision=[14, 14] + ) + l19, l20 = sch.split(loop=l13, factors=[v17, v18], preserve_unit_iters=True) + v21, v22 = sch.sample_perfect_tile(loop=l14, n=2, max_innermost_factor=64, decision=[8, 8]) + l23, l24 = sch.split(loop=l14, factors=[v21, v22], preserve_unit_iters=True) + sch.unroll(loop=l11) + sch.unroll(loop=l12) + sch.unroll(loop=l15) + sch.unroll(loop=l16) + sch.reorder(l19, l23, l20, l24, l11, l12, l15, l16) + sch.compute_at(block=b9, loop=l24, preserve_unit_loops=True, index=-1) + sch.set_scope(block=b9, buffer_index=0, storage_scope="local") + sch.compute_inline(block=b10) + l25, l26, l27, l28, l29, l30, l31, l32 = sch.get_loops(block=b1) + l33 = sch.fuse(l25, l26, l27, l28, preserve_unit_iters=True) + v34 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=2, + ) + l35, l36 = sch.split(loop=l33, factors=[None, v34], preserve_unit_iters=True) + sch.bind(loop=l35, thread_axis="blockIdx.x") + sch.bind(loop=l36, thread_axis="threadIdx.x") + sch.compute_inline(block=b3) + l37, l38, l39, l40, l41, l42 = sch.get_loops(block=b4) + v43, v44 = sch.sample_perfect_tile(loop=l39, n=2, max_innermost_factor=64, decision=[28, 7]) + l45, l46 = sch.split(loop=l39, factors=[v43, v44], preserve_unit_iters=True) + v47, v48 = sch.sample_perfect_tile(loop=l40, n=2, max_innermost_factor=64, decision=[2, 32]) + l49, l50 = sch.split(loop=l40, factors=[v47, v48], preserve_unit_iters=True) + sch.unroll(loop=l37) + sch.unroll(loop=l38) + sch.unroll(loop=l41) + sch.unroll(loop=l42) + sch.reorder(l45, l49, l46, l50, l37, l38, l41, l42) + l51, l52, l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b4) + l59 = sch.fuse(l51, l52, l53, l54, preserve_unit_iters=True) + v60 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=4, + ) + l61, l62 = sch.split(loop=l59, factors=[None, v60], preserve_unit_iters=True) + sch.bind(loop=l61, thread_axis="blockIdx.x") + sch.bind(loop=l62, thread_axis="threadIdx.x") + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + l63, l64, l65, l66, l67 = sch.get_loops(block=b2) + v68, v69, v70, v71, v72 = sch.sample_perfect_tile( + loop=l63, n=5, max_innermost_factor=64, decision=[1, 2, 3, 1, 1] + ) + l73, l74, l75, l76, l77 = sch.split( + loop=l63, factors=[v68, v69, v70, v71, v72], preserve_unit_iters=True + ) + v78, v79, v80, v81, v82 = sch.sample_perfect_tile( + loop=l64, n=5, max_innermost_factor=64, decision=[6, 1, 1, 1, 1] + ) + l83, l84, l85, l86, l87 = sch.split( + loop=l64, factors=[v78, v79, v80, v81, v82], preserve_unit_iters=True + ) + v88, v89, v90, v91, v92 = sch.sample_perfect_tile( + loop=l65, n=5, max_innermost_factor=64, decision=[7, 2, 1, 14, 1] + ) + l93, l94, l95, l96, l97 = sch.split( + loop=l65, factors=[v88, v89, v90, v91, v92], preserve_unit_iters=True + ) + v98, v99, v100, v101, v102 = sch.sample_perfect_tile( + loop=l66, n=5, max_innermost_factor=64, decision=[4, 1, 16, 1, 1] + ) + l103, l104, l105, l106, l107 = sch.split( + loop=l66, factors=[v98, v99, v100, v101, v102], preserve_unit_iters=True + 
) + v108, v109, v110 = sch.sample_perfect_tile( + loop=l67, n=3, max_innermost_factor=64, decision=[2, 2, 16] + ) + l111, l112, l113 = sch.split(loop=l67, factors=[v108, v109, v110], preserve_unit_iters=True) + sch.reorder( + l73, + l83, + l93, + l103, + l74, + l84, + l94, + l104, + l75, + l85, + l95, + l105, + l111, + l112, + l76, + l86, + l96, + l106, + l113, + l77, + l87, + l97, + l107, + ) + l114 = sch.fuse(l73, l83, l93, l103, preserve_unit_iters=True) + sch.bind(loop=l114, thread_axis="blockIdx.x") + l115 = sch.fuse(l74, l84, l94, l104, preserve_unit_iters=True) + sch.bind(loop=l115, thread_axis="vthread.x") + l116 = sch.fuse(l75, l85, l95, l105, preserve_unit_iters=True) + sch.bind(loop=l116, thread_axis="threadIdx.x") + sch.annotate( + block_or_loop=b2, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32 + ) + sch.annotate( + block_or_loop=b2, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024 + ) + b117 = sch.cache_write(block=b2, write_buffer_index=0, storage_scope="local") + sch.reverse_compute_at(block=b117, loop=l116, preserve_unit_loops=True, index=-1) + b118 = sch.cache_read( + block=b2, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b2] + ) + sch.compute_at(block=b118, loop=l111, preserve_unit_loops=True, index=-1) + l119, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b118) + l127 = sch.fuse(l123, l124, l125, l126, preserve_unit_iters=True) + v128 = sch.sample_categorical( + candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate(block_or_loop=b118, ann_key="meta_schedule.cooperative_fetch", ann_val=v128) + b129 = sch.cache_read( + block=b2, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b2] + ) + sch.compute_at(block=b129, loop=l111, preserve_unit_loops=True, index=-1) + l130, l131, l132, l133, l134, l135, l136, l137 = sch.get_loops(block=b129) + l138 = sch.fuse(l134, l135, l136, l137, preserve_unit_iters=True) + v139 = sch.sample_categorical( + candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3 + ) + sch.annotate(block_or_loop=b129, ann_key="meta_schedule.cooperative_fetch", ann_val=v139) + sch.reverse_compute_inline(block=b7) + sch.reverse_compute_inline(block=b6) + v140 = sch.sample_categorical( + candidates=[0, 16, 64, 512, 1024], + probs=[ + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + ], + decision=4, + ) + sch.annotate(block_or_loop=b8, ann_key="meta_schedule.unroll_explicit", ann_val=v140) + l141, l142, l143, l144 = sch.get_loops(block=b5) + l145 = sch.fuse(l141, l142, l143, l144, preserve_unit_iters=True) + v146 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=2, + ) + l147, l148 = sch.split(loop=l145, factors=[None, v146], preserve_unit_iters=True) + sch.bind(loop=l147, thread_axis="blockIdx.x") + sch.bind(loop=l148, thread_axis="threadIdx.x") + sch.enter_postproc() + sch.unannotate(block_or_loop=b118, ann_key="meta_schedule.cooperative_fetch") + l149, l150, l151, l152, l153 = sch.get_loops(block=b118) + l154, l155, l156 = sch.split(loop=l153, factors=[None, 48, 4], preserve_unit_iters=True) + sch.vectorize(loop=l156) + sch.bind(loop=l155, thread_axis="threadIdx.x") + sch.unannotate(block_or_loop=b129, ann_key="meta_schedule.cooperative_fetch") + l157, l158, l159, l160, l161 = 
sch.get_loops(block=b129) + l162, l163, l164 = sch.split(loop=l161, factors=[None, 48, 4], preserve_unit_iters=True) + sch.vectorize(loop=l164) + sch.bind(loop=l163, thread_axis="threadIdx.x") + b165 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b165, ann_key="meta_schedule.unroll_explicit") + b166, b167, b168, b169, b170, b171, b172, b173 = sch.get_child_blocks(b165) + l174, l175, l176, l177, l178, l179 = sch.get_loops(block=b166) + sch.annotate(block_or_loop=l174, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l174, ann_key="pragma_unroll_explicit", ann_val=1) + l180, l181, l182, l183, l184, l185 = sch.get_loops(block=b167) + sch.annotate(block_or_loop=l180, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l180, ann_key="pragma_unroll_explicit", ann_val=1) + l186, l187, l188, l189, l190, l191, l192 = sch.get_loops(block=b168) + sch.annotate(block_or_loop=l186, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l186, ann_key="pragma_unroll_explicit", ann_val=1) + l193, l194, l195, l196, l197, l198, l199 = sch.get_loops(block=b169) + sch.annotate(block_or_loop=l193, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l193, ann_key="pragma_unroll_explicit", ann_val=1) + ( + l200, + l201, + l202, + l203, + l204, + l205, + l206, + l207, + l208, + l209, + l210, + l211, + l212, + l213, + ) = sch.get_loops(block=b170) + sch.annotate(block_or_loop=l200, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l200, ann_key="pragma_unroll_explicit", ann_val=1) + l214, l215, l216, l217, l218, l219, l220 = sch.get_loops(block=b171) + sch.annotate(block_or_loop=l214, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l214, ann_key="pragma_unroll_explicit", ann_val=1) + l221, l222, l223, l224, l225, l226 = sch.get_loops(block=b172) + sch.annotate(block_or_loop=l221, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l221, ann_key="pragma_unroll_explicit", ann_val=1) + l227, l228 = sch.get_loops(block=b173) + sch.annotate(block_or_loop=l227, ann_key="pragma_auto_unroll_max_step", ann_val=1024) + sch.annotate(block_or_loop=l227, ann_key="pragma_unroll_explicit", ann_val=1) + b229 = sch.get_block(name="data_pack", func_name="main") + l230, l231, l232, l233, l234, l235 = sch.get_loops(block=b229) + b236 = sch.decompose_reduction(block=b229, loop=l234) + b237 = sch.get_block(name="bgemm", func_name="main") + ( + l238, + l239, + l240, + l241, + l242, + l243, + l244, + l245, + l246, + l247, + l248, + l249, + l250, + l251, + ) = sch.get_loops(block=b237) + b252 = sch.decompose_reduction(block=b237, loop=l241) + b253 = sch.get_block(name="inverse", func_name="main") + l254, l255, l256, l257, l258, l259 = sch.get_loops(block=b253) + b260 = sch.decompose_reduction(block=b253, loop=l258) + + verify( + Conv2dWinogradAddRelu, + apply_trace, + Conv2dWinogradAddResidualRelu, + "cuda", + Conv2dWinogradAddResidualRelu_scheduled, + ) if __name__ == "__main__": From 0423ceceb568ac89cea9e8d9962a79a12ac8ea4d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Oct 2022 17:34:43 +0900 Subject: [PATCH 19/28] add doc --- python/tvm/meta_schedule/trace_apply.py | 22 ++++++++++-- python/tvm/tir/schedule/analysis.py | 16 ++++++++- src/meta_schedule/module_equality.cc | 2 ++ src/meta_schedule/trace_apply.cc | 20 ++++++++++- src/relay/backend/task_extraction.cc | 8 +++++ 
 src/tir/schedule/utils.h | 4 +++
 .../test_meta_schedule_relay_integration.py | 35 +++++++++++++++++--
 7 files changed, 100 insertions(+), 7 deletions(-)

diff --git a/python/tvm/meta_schedule/trace_apply.py b/python/tvm/meta_schedule/trace_apply.py
index 364a84344684..351173b405d2 100644
--- a/python/tvm/meta_schedule/trace_apply.py
+++ b/python/tvm/meta_schedule/trace_apply.py
@@ -14,9 +14,25 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TODO"""
+"""Specialized applications of schedule traces"""
+from ..tir.schedule import Schedule, Trace
 from . import _ffi_api


-def schedule_using_anchor_trace(sch, anchor_trace, target):
-    return _ffi_api.ScheduleUsingAnchorTrace(sch, anchor_trace, target) # type: ignore # pylint: disable=no-member
+def schedule_using_anchor_trace(sch: Schedule, anchor_trace: Trace, target) -> None:
+    """Apply the trace from a TIR module whose anchor block is the same but whose fused
+    elementwise op blocks differ. This function can be used, for example, to transfer a trace
+    tuned on a conv2d -> add subgraph to other subgraphs having the same conv2d workload.
+    We call such a trace an "anchor trace". Blocks that are not scheduled by the given anchor
+    trace will be either inlined or parallelized.
+
+    Parameters
+    ----------
+    sch : Schedule
+        The target schedule
+    anchor_trace : Trace
+        The trace generated for another TIR module having the same anchor block
+    target : tvm.target.Target
+        The compilation target
+    """
+    _ffi_api.ScheduleUsingAnchorTrace(sch, anchor_trace, target) # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py
index 5a4e5840ead0..52aafbaf067d 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -124,5 +124,19 @@ def get_auto_tensorize_mapping_info(
     return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) # type: ignore


-def has_block(sch, block_name):
+def has_block(sch: Schedule, block_name: str) -> bool:
+    """Query if the given block name exists in the module associated with the provided schedule.
+
+    Parameters
+    ----------
+    sch : Schedule
+        The schedule
+    block_name : str
+        The name of the block to query
+
+    Returns
+    -------
+    result : bool
+        True if the given block exists in the schedule.
+    """
     return _ffi_api.HasBlock(sch, block_name)
diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc
index 9340b373d5b4..f9ffe82aa271 100644
--- a/src/meta_schedule/module_equality.cc
+++ b/src/meta_schedule/module_equality.cc
@@ -74,6 +74,8 @@ class ModuleEqualityIgnoreNDArray : public ModuleEquality {
   }
 };

+// The NDArray-ignoring variant of structural equal / hash is used for the module equality
+// on the extracted anchor blocks.
class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { auto anchor_block = tir::FindAnchorBlock(mod); diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index a1a8ec61c6da..fb8eb0294aac 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -33,6 +33,7 @@ namespace meta_schedule { using namespace tir; +// Returns true if b1 is an ancestor of b2 bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) { if (sch->Get(b1)->name_hint == sch->Get(b2)->name_hint) { return true; @@ -45,6 +46,8 @@ bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) { void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { static auto kind_get_block = InstructionKind::Get("GetBlock"); + // We let blocks whose names are referenced in the anchor trace be scheduled by the anchor trace. + // We record such block names to avoid inlining them here. std::unordered_set get_block_names; for (const auto& inst : anchor_trace->insts) { if (inst->kind.same_as(kind_get_block)) { @@ -65,12 +68,15 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { for (auto name : GetBlockNames(sch->mod())) { auto block = sch->GetBlock(name); if (anchor_block_rv && IsAncestor(block, *anchor_block_rv, sch)) continue; + // Spatial blocks which are not referenced in the anchor trace will be inlined here. if (IsSpatial(sch->GetSRef(block)) && !get_block_names.count(name)) { inline_rule->Apply(sch, block); } } } +// Apply instructions from the anchor trace to the target schedule, and returns blocks +// that remain unscheduled. std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks"); static auto kind_get_block = InstructionKind::Get("GetBlock"); @@ -81,9 +87,12 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { const auto sch_orig = sch->Copy(); std::unordered_map rv_map; + // Blocks and loops that appear in the anchor trace but are not part of the target schedule. std::unordered_set foreign_blocks; std::unordered_set foreign_loops; + // Instructions in the anchor trace can be applied only if all inputs are part of the target + // schedule. auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) { for (auto input : inst->inputs) { if (!input.defined()) continue; @@ -97,6 +106,8 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { for (const auto& inst : anchor_trace->insts) { if (!is_inst_applicable(inst)) { + // If we find an instruction that is not applicable, its outputs are recorded as "foreign" + // to the target schedule. for (auto output : inst->outputs) { if (output->IsInstance()) { foreign_blocks.insert(Downcast(output)); @@ -110,10 +121,14 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { Array inputs = TranslateInputRVs(inst->inputs, rv_map); if (inst->kind.same_as(kind_get_block) && !HasBlock(sch, Downcast(inst->attrs[0]))) { + // The anchor trace does get_block on a block that is not part of the target schedule. auto block = Downcast(inst->outputs[0]); foreign_blocks.insert(block); continue; } else if (inst->kind.same_as(kind_reverse_compute_inline)) { + // The anchor trace does reverse_compute_inline on a block, but the block with the same name + // in the target schedule cannot be reverse compute inline-ed. + // In such cases, it should be possible to apply compute_inline instead. 
auto block = Downcast(inputs[0]); auto block_sref = sch->GetSRef(block); if (!CanReverseComputeInline(sch->state(), block_sref)) { @@ -122,6 +137,7 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { continue; } } else if (inst->kind.same_as(kind_compute_inline)) { + // Similar to the reverse_compute_inline case above. auto block = Downcast(inputs[0]); auto block_sref = sch->GetSRef(block); if (!CanComputeInline(sch->state(), block_sref)) { @@ -182,7 +198,9 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm InlinePostBlocks(sch, anchor_trace, target); auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace); - ICHECK(unscheduled_blocks.size() <= 1); + ICHECK(unscheduled_blocks.size() <= 1) + << "All blocks should have been scheduled or only one (fused) spatial block can remain " + "unscheduled at this point."; if (unscheduled_blocks.empty()) { // All blocks have already been scheduled. diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 1dd1ed3ca701..7e66dafe16f5 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -100,6 +100,14 @@ Array ExtractTask(IRModule mod, Target target, op_counts[i] = OpCounter::GetOpCount(std::get<1>(lower_results[i])); } + // When anchor-block based equality is used, tuning tasks "nn_conv2d_add_nn_relu" and + // "nn_conv2d_add_add_nn_relu", for example, can be identified as equal. Thus, one of + // them will be filtered by the cache below. + // + // To make sure that we tune "nn_conv2d_add_nn_relu" and not "nn_conv2d_add_add_nn_relu", + // we sort the TE lowering results based on the number of relay ops. This way, + // "nn_conv2d_add_nn_relu" will be added to the cache first, and "nn_conv2d_add_add_nn_relu" + // will be filtered. std::sort(indices.begin(), indices.end(), [&op_counts](int i1, int i2) { return op_counts[i1] < op_counts[i2]; }); } diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index e7363bd20a34..eafaf096b0eb 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -442,6 +442,9 @@ inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { } } +/******** Utilities for retrieving information about blocks ********/ + +/*! \brief Returns the names of the blocks in the provided module. */ inline std::unordered_set GetBlockNames(const IRModule& mod) { struct BlockNameCollector : public tir::StmtVisitor { void VisitStmt_(const tir::BlockNode* block) override { @@ -457,6 +460,7 @@ inline std::unordered_set GetBlockNames(const IRModule& mod) { return collector.block_names; } +/*! 
\brief Query if the given block name exists in the module associated with the schedule */ inline bool HasBlock(const Schedule& sch, const std::string& block_name) { auto block_names = GetBlockNames(sch->mod()); return block_names.count(block_name); diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index 1818177eada9..aae0fa8a3f74 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -106,10 +106,40 @@ def test_meta_schedule_integration_extract_from_resnet(): for t in extracted_tasks: assert t.task_name in expected_task_names, t.task_name + +@requires_torch +def test_task_extraction_anchor_block(): + mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) extracted_tasks = ms.relay_integration.extract_tasks( mod, target="llvm", params=params, module_equality="anchor-block" ) - assert len(extracted_tasks) == 16 + + # Note that there is no task from residual blocks + expected_task_names = [ + "fused_" + s + for s in [ + "nn_max_pool2d", + "nn_adaptive_avg_pool2d", + "nn_dense_add", + "nn_conv2d_add", + "nn_conv2d_add_1", + "nn_conv2d_add_2", + "nn_conv2d_add_nn_relu", + "nn_conv2d_add_nn_relu_1", + "nn_conv2d_add_nn_relu_2", + "nn_conv2d_add_nn_relu_3", + "nn_conv2d_add_nn_relu_4", + "nn_conv2d_add_nn_relu_5", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1", + "layout_transform", + "layout_transform_reshape_squeeze", + ] + ] + + assert len(extracted_tasks) == len(expected_task_names) + for t in extracted_tasks: + assert t.task_name in expected_task_names, t.task_name @requires_torch @@ -679,4 +709,5 @@ def test_module_equality_ignore_ndarray(): if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_task_extraction_anchor_block() From abb2d0bdfadb872c537e6221e38189228c0b7818 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Oct 2022 20:01:44 +0900 Subject: [PATCH 20/28] use anchor tuning in hexagon int8 tuning test --- .../test_hexagon/metaschedule_e2e/test_resnet50_int8.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index a541c25f3cbc..addbb052a2da 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -116,9 +116,11 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher): postprocs=postprocs, mutator_probs={}, ), - # Without this, the same workloads with different constant weights - # are treated as distinct tuning tasks. - module_equality="ignore-ndarray", + # This enables anchor-block tuning, where different subgraphs + # with the same anchor block workload will be identified as equal. + # It reduces the number of conv2d tuning tasks in the int8 resnet50 model + # from 36 to 23, with negligible performance difference. 
+ module_equality="anchor-block", ) return ms.relay_integration.compile_relay( From 1e7db843d194d1b1c01ecaae9a6f6e79e0686398 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Oct 2022 06:05:52 +0900 Subject: [PATCH 21/28] cpplint --- src/meta_schedule/utils.h | 1 + src/tir/schedule/utils.h | 2 ++ tests/python/unittest/test_meta_schedule_relay_integration.py | 3 +-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 949588168534..7240fa418839 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -44,6 +44,7 @@ #include #include +#include #include #include diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index eafaf096b0eb..bcc8b7facbc9 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -31,7 +31,9 @@ #include #include +#include #include +#include #include #include "../../arith/pattern_match.h" diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index aae0fa8a3f74..5c48aad7d58d 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -709,5 +709,4 @@ def test_module_equality_ignore_ndarray(): if __name__ == "__main__": - # tvm.testing.main() - test_task_extraction_anchor_block() + tvm.testing.main() From c84bf2172a762f428308b011351a7358c13c5e51 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Oct 2022 06:24:56 +0900 Subject: [PATCH 22/28] suppress mypy on ffi --- python/tvm/meta_schedule/trace_apply.py | 2 +- python/tvm/tir/schedule/analysis.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/trace_apply.py b/python/tvm/meta_schedule/trace_apply.py index 351173b405d2..3cea6c5b66fe 100644 --- a/python/tvm/meta_schedule/trace_apply.py +++ b/python/tvm/meta_schedule/trace_apply.py @@ -35,4 +35,4 @@ def schedule_using_anchor_trace(sch: Schedule, anchor_trace: Trace, target) -> N target : tvm.target.Target The compilation target """ - _ffi_api.ScheduleUsingAnchorTrace(sch, anchor_trace, target) # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleUsingAnchorTrace(sch, anchor_trace, target) # type: ignore diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 52aafbaf067d..e1c0019d9bf0 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -139,4 +139,4 @@ def has_block(sch: Schedule, block_name: str) -> bool: yes/no: bool True if the given block exists in the schedule. 
""" - return _ffi_api.HasBlock(sch, block_name) + return _ffi_api.HasBlock(sch, block_name) # type: ignore From 346d55fbfce50e2961898b84729a0170a0702613 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Oct 2022 07:53:22 +0900 Subject: [PATCH 23/28] add workaround for false positive maybe-uninitialized warning --- src/meta_schedule/trace_apply.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index fb8eb0294aac..c95cb2072720 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -58,16 +58,15 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { } auto anchor_block = FindAnchorBlock(sch->mod()); - std::optional anchor_block_rv = std::nullopt; - if (anchor_block) { - anchor_block_rv = sch->GetBlock(anchor_block->name_hint); - } auto inline_rule = GetDefaultAutoInline(target->kind->name); for (auto name : GetBlockNames(sch->mod())) { auto block = sch->GetBlock(name); - if (anchor_block_rv && IsAncestor(block, *anchor_block_rv, sch)) continue; + if (anchor_block) { + auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint); + if (IsAncestor(block, anchor_block_rv, sch)) continue; + } // Spatial blocks which are not referenced in the anchor trace will be inlined here. if (IsSpatial(sch->GetSRef(block)) && !get_block_names.count(name)) { inline_rule->Apply(sch, block); From b88e63b63bdd6a6819c99321898ebc9223022618 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Oct 2022 08:04:05 +0900 Subject: [PATCH 24/28] add a minimal anchor tuning test --- .../test_meta_schedule_relay_integration.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index 5c48aad7d58d..5250092221ef 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -708,5 +708,68 @@ def test_module_equality_ignore_ndarray(): np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4) +def _test_anchor_tuning(target): + data_shape = (128, 128) + weight_shape1 = (128, 128) + weight_shape2 = (128, 128) + + data = relay.var("data", shape=data_shape, dtype="float32") + weight1 = relay.var("weight1", shape=weight_shape1, dtype="float32") + weight2 = relay.var("weight2", shape=weight_shape2, dtype="float32") + dense1 = relay.nn.dense(data, weight1) + dense2 = relay.nn.dense(dense1 + relay.const(1.0, dtype="float32"), weight2) + mod = tvm.IRModule.from_expr(dense2 - data + relay.const(1.0, dtype="float32")) + + weight1_np = np.random.randn(*weight_shape1).astype("float32") + weight2_np = np.random.randn(*weight_shape2).astype("float32") + + data_np = np.random.randn(*data_shape).astype("float32") + params = {"weight1": weight1_np, "weight2": weight2_np} + + module_equality = "anchor-block" + + extracted_tasks = ms.relay_integration.extract_tasks( + mod, target, params, module_equality=module_equality + ) + + assert len(extracted_tasks) == 1 + + with tempfile.TemporaryDirectory() as work_dir: + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + params=params, + work_dir=work_dir, + max_trials_global=4, + strategy="replay-trace", + module_equality=module_equality, + ) + lib = ms.relay_integration.compile_relay(database, mod, target, params) + + dev = tvm.device(target, 0) + runtime = 
tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + out = runtime.get_output(0).numpy() + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight1_np, weight2_np]) + .numpy() + ) + + np.testing.assert_allclose(ref, out, rtol=1e-5, atol=1e-5) + + +def test_anchor_tuning_cpu(): + _test_anchor_tuning("llvm --num-cores=4") + + +@tvm.testing.requires_gpu +def test_anchor_tuning_gpu(): + _test_anchor_tuning("nvidia/geforce-rtx-3070") + + if __name__ == "__main__": tvm.testing.main() From 2f379003f9b1862b070f93e8482f7af03b4ea40c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Oct 2022 10:16:49 +0900 Subject: [PATCH 25/28] relax tol for i386, remove gpu test since it requires sm86 --- .../unittest/test_meta_schedule_relay_integration.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index 5250092221ef..f18cedfb5a3b 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -759,17 +759,12 @@ def _test_anchor_tuning(target): .numpy() ) - np.testing.assert_allclose(ref, out, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(ref, out, atol=1e-3) def test_anchor_tuning_cpu(): _test_anchor_tuning("llvm --num-cores=4") -@tvm.testing.requires_gpu -def test_anchor_tuning_gpu(): - _test_anchor_tuning("nvidia/geforce-rtx-3070") - - if __name__ == "__main__": tvm.testing.main() From 5e2367a8f5e38b76036baffc62cd43d6256f8745 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 Oct 2022 10:43:36 +0900 Subject: [PATCH 26/28] add doc for "anchor-block" module equality --- include/tvm/meta_schedule/database.h | 8 ++++++++ python/tvm/meta_schedule/database/json_database.py | 4 ++++ python/tvm/meta_schedule/database/memory_database.py | 4 ++++ python/tvm/meta_schedule/database/schedule_fn_database.py | 4 ++++ python/tvm/meta_schedule/relay_integration.py | 8 ++++++++ python/tvm/meta_schedule/tune.py | 4 ++++ src/meta_schedule/module_equality.h | 4 ++++ 7 files changed, 36 insertions(+) diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 9eead8d5ec31..a1dd4a412eec 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -183,6 +183,10 @@ class DatabaseNode : public runtime::Object { * - "structural": Use StructuralEqual/Hash * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during * equality testing and hashing. + * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + * given module. The "ignore-ndarray" variant is used for the extracted blocks + * or in case no anchor block is found. + * For the definition of the anchor block, see tvm/tir/analysis.h. */ explicit DatabaseNode(String mod_eq_name = "structural"); @@ -274,6 +278,10 @@ class PyDatabaseNode : public DatabaseNode { * - "structural": Use StructuralEqual/Hash * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during * equality testing and hashing. + * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + * given module. The "ignore-ndarray" variant is used for the extracted blocks + * or in case no anchor block is found.
+ * For the definition of the anchor block, see tvm/tir/analysis.h. */ explicit PyDatabaseNode(String mod_eq_name = "structural"); diff --git a/python/tvm/meta_schedule/database/json_database.py b/python/tvm/meta_schedule/database/json_database.py index f81d8913c18a..102a13b90d98 100644 --- a/python/tvm/meta_schedule/database/json_database.py +++ b/python/tvm/meta_schedule/database/json_database.py @@ -40,6 +40,10 @@ class JSONDatabase(Database): - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" variant is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. """ path_workload: str diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index 96b9bb5a0112..34a6a141970a 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -33,6 +33,10 @@ class MemoryDatabase(Database): - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" variant is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. """ def __init__( diff --git a/python/tvm/meta_schedule/database/schedule_fn_database.py b/python/tvm/meta_schedule/database/schedule_fn_database.py index 7a0b433996c5..c7d175cb79d3 100644 --- a/python/tvm/meta_schedule/database/schedule_fn_database.py +++ b/python/tvm/meta_schedule/database/schedule_fn_database.py @@ -39,6 +39,10 @@ class ScheduleFnDatabase(Database): - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" variant is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. """ def __init__( diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 089f6e412e20..5e77181d32bf 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -143,6 +143,10 @@ def extract_tasks( - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" variant is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. Returns ------- @@ -288,6 +292,10 @@ def tune_relay( - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module.
The "ignore-ndarray" variant is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. Returns ------- diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 07021eac3998..a69c8f126272 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -76,6 +76,10 @@ def tune_tasks( - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. + - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + given module. The "ignore-ndarray" variant is used for the extracted + blocks or in case no anchor block is found. + For the definition of the anchor block, see tir/analysis/analysis.py. Returns ------- diff --git a/src/meta_schedule/module_equality.h b/src/meta_schedule/module_equality.h index 8c99b563551b..ba5877471e2c 100644 --- a/src/meta_schedule/module_equality.h +++ b/src/meta_schedule/module_equality.h @@ -42,6 +42,10 @@ class ModuleEquality { * - "structural": Use StructuralEqual/Hash * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during * equality testing and hashing. + * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a + * given module. The "ignore-ndarray" variant is used for the extracted blocks + * or in case no anchor block is found. + * For the definition of the anchor block, see tvm/tir/analysis.h. * \return An owning pointer to the created instance */ static std::unique_ptr Create(const std::string& mod_eq_name); From d6893af5bf3337cb1f44dd784d77a7d9937f0980 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 28 Oct 2022 16:50:33 +0900 Subject: [PATCH 27/28] address comments --- python/tvm/meta_schedule/trace_apply.py | 3 ++- src/meta_schedule/trace_apply.cc | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/trace_apply.py b/python/tvm/meta_schedule/trace_apply.py index 3cea6c5b66fe..c621cf973af2 100644 --- a/python/tvm/meta_schedule/trace_apply.py +++ b/python/tvm/meta_schedule/trace_apply.py @@ -16,10 +16,11 @@ # under the License. """Specialized applications of trace""" from ..tir.schedule import Schedule, Trace +from ..target import Target from . import _ffi_api -def schedule_using_anchor_trace(sch: Schedule, anchor_trace: Trace, target) -> None: +def schedule_using_anchor_trace(sch: Schedule, anchor_trace: Trace, target: Target) -> None: """Apply the trace from a TIR module whose anchor block is the same but fused elementwise op blocks differ. This function can be used for transferring a trace tuned on a conv2d -> add subgraph to other subgraphs having the same conv2d workload, for example. We call such a trace an "anchor trace". Those blocks that are not scheduled by the given anchor trace will be either inlined or parallelized. diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index c95cb2072720..70b6451d3546 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -44,6 +44,7 @@ bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) { return false; } +// Inline or reverse inline spatial blocks after the anchor block void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { static auto kind_get_block = InstructionKind::Get("GetBlock"); // We let blocks whose names are referenced in the anchor trace be scheduled by the anchor trace.
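A minimal usage sketch of the Python entry point above, assuming `fused_mod` is a TIR module containing extra fused elementwise blocks around the anchor block and `anchor_record` is a tuning record retrieved from a database for a subgraph sharing that anchor block (both names are illustrative, not part of the patches):

    import tvm
    from tvm import tir
    from tvm.meta_schedule.trace_apply import schedule_using_anchor_trace

    # `fused_mod` and `anchor_record` are hypothetical; substitute your own module and record.
    sch = tir.Schedule(fused_mod)
    # Replay the anchor trace on the new subgraph; blocks it does not cover are
    # inlined or parallelized by the target's default rules.
    schedule_using_anchor_trace(sch, anchor_trace=anchor_record.trace, target=tvm.target.Target("llvm"))
    print(sch.mod.script())  # inspect the resulting schedule
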
From 1fe554de1aeadb159f61936174f93a9c2c3b753f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 28 Oct 2022 16:59:54 +0900 Subject: [PATCH 28/28] add test for cache_write + AllocateConst bug --- .../test_meta_schedule_relay_integration.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index f18cedfb5a3b..c689a15c56b2 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -766,5 +766,57 @@ def test_anchor_tuning_cpu(): _test_anchor_tuning("llvm --num-cores=4") +def test_anchor_tuning_cpu_link_params(): + data_shape = (128, 128) + weight_shape1 = (128, 128) + weight_shape2 = (128, 128) + + data = relay.var("data", shape=data_shape, dtype="float32") + weight1 = relay.var("weight1", shape=weight_shape1, dtype="float32") + weight2 = relay.var("weight2", shape=weight_shape2, dtype="float32") + dense1 = relay.nn.dense(data, weight1) + dense2 = relay.nn.dense(dense1, weight2) + mod = tvm.IRModule.from_expr(dense2 + relay.const(1.0, dtype="float32")) + + weight1_np = np.random.randn(*weight_shape1).astype("float32") + weight2_np = np.random.randn(*weight_shape2).astype("float32") + + data_np = np.random.randn(*data_shape).astype("float32") + params = {"weight1": weight1_np, "weight2": weight2_np} + + module_equality = "anchor-block" + target = "llvm --num-cores=4" + + executor = relay.backend.Executor("graph", {"link-params": True}) + mod = mod.with_attr("executor", executor) + + with tempfile.TemporaryDirectory() as work_dir: + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + params=params, + work_dir=work_dir, + max_trials_global=4, + strategy="replay-trace", + module_equality=module_equality, + ) + lib = ms.relay_integration.compile_relay(database, mod, target, params) + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + out = runtime.get_output(0).numpy() + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight1_np, weight2_np]) + .numpy() + ) + + np.testing.assert_allclose(ref, out, atol=1e-3) + + if __name__ == "__main__": tvm.testing.main()
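
As a closing usage sketch of the `module_equality` option documented in patch 26, a database can be constructed so that workloads are deduplicated by their anchor block; the file paths below are illustrative only:

    from tvm import meta_schedule as ms

    # Records for subgraphs sharing the same anchor block (e.g. conv2d + add vs.
    # conv2d + add + add) are committed and looked up under the same workload entry.
    db = ms.database.JSONDatabase(
        path_workload="workloads.json",
        path_tuning_record="records.json",
        module_equality="anchor-block",
    )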