From 7f7c50e6b903f845d72c13135c43e84751c8af0a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 Aug 2022 16:12:45 -0500 Subject: [PATCH] [TIR] More hygienic TVM_SREF macros Previously, the `TVM_SREF_TO_BLOCK`, `TVM_SREF_TO_FOR`, and `TVM_TYPE_AS` macros required both an input and an output variable name. The input variable name is useful for improving the returned error message, but the output variable name isn't necessary for this functionality and prevents the macro from being used as part of an expression. * Generate an immediately-invoked lambda expression to allow for an independently-scoped `result` variable. * Use parentheses around the input argument, in case the sref is the result of an expression. * Update all call sites to remove the argument that previously provided the output variable name. --- src/meta_schedule/mutator/mutate_parallel.cc | 4 +- .../mutator/mutate_thread_binding.cc | 8 +-- src/meta_schedule/mutator/mutate_tile_size.cc | 4 +- src/meta_schedule/mutator/mutate_unroll.cc | 4 +- .../rewrite_parallel_vectorize_unroll.cc | 4 +- src/meta_schedule/schedule_rule/auto_bind.cc | 2 +- .../schedule_rule/auto_inline.cc | 2 +- .../schedule_rule/multi_level_tiling.cc | 2 +- .../multi_level_tiling_tensor_core.cc | 4 +- .../schedule_rule/random_compute_location.cc | 2 +- src/meta_schedule/utils.h | 2 +- src/tir/schedule/analysis/analysis.cc | 48 ++++++++--------- src/tir/schedule/block_scope.cc | 2 +- src/tir/schedule/concrete_schedule.cc | 4 +- src/tir/schedule/concrete_schedule.h | 6 +-- src/tir/schedule/primitive/block_annotate.cc | 6 +-- .../schedule/primitive/blockize_tensorize.cc | 2 +- .../schedule/primitive/cache_read_write.cc | 14 ++--- src/tir/schedule/primitive/compute_at.cc | 12 ++--- src/tir/schedule/primitive/compute_inline.cc | 8 +-- .../schedule/primitive/decompose_padding.cc | 2 +- src/tir/schedule/primitive/for_kind.cc | 4 +- src/tir/schedule/primitive/get_block_loop.cc | 2 +- .../primitive/layout_transformation.cc | 10 ++-- .../schedule/primitive/loop_transformation.cc | 10 ++-- src/tir/schedule/primitive/reduction.cc | 12 ++--- src/tir/schedule/primitive/sampling.cc | 2 +- src/tir/schedule/state.cc | 14 ++--- src/tir/schedule/transform.cc | 6 +-- src/tir/schedule/utils.h | 51 ++++++++++++------- 30 files changed, 133 insertions(+), 120 deletions(-) diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 5b7fe7f5148d..82b91da682c6 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -64,7 +64,7 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { return nullptr; } ICHECK_EQ(inst->outputs.size(), 1); - const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode); + const BlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], BlockRVNode); return block; } @@ -82,7 +82,7 @@ std::vector> AnalyzeParallel(const ScheduleState& self, Array block_srefs = tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name)); ICHECK_EQ(block_srefs.size(), 1); - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]); ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); std::vector> results; results.reserve(info.realizes.size()); diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 41207162ee1d..de780b53e2d9 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++
b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -109,12 +109,12 @@ std::vector MutateThreadBindingNode::FindCan for (const Instruction& inst : trace->insts) { if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); - const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode); sample_insts[var_rv] = inst.get(); } else if (is_split_by_sample(inst)) { CHECK_EQ(inst->outputs.size(), 2); // Only consider the inner loop, which can be bound to threadIdx.x - const tir::LoopRVNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[1], tir::LoopRVNode); + const tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], tir::LoopRVNode); sampled_split_insts[var_rv] = inst.get(); } else if (is_thread_binding_by_sample(inst)) { bind_insts.push_back(inst.get()); @@ -122,12 +122,12 @@ std::vector MutateThreadBindingNode::FindCan } for (const InstructionNode* bind_inst : bind_insts) { - const auto* loop_rv = TVM_TYPE_AS(loop_rv, bind_inst->inputs[0], tir::LoopRVNode); + const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], tir::LoopRVNode); auto split_it = sampled_split_insts.find(loop_rv); ICHECK(split_it != sampled_split_insts.end()); const InstructionNode* split_inst = split_it->second; - const auto* expr_rv = TVM_TYPE_AS(expr_rv, split_inst->inputs[2], PrimExprNode); + const auto* expr_rv = TVM_TYPE_AS(split_inst->inputs[2], PrimExprNode); auto sample_it = sample_insts.find(expr_rv); ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 00967aef7acd..4a3bfda8a4a8 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -34,7 +34,7 @@ using tir::Trace; * \return The result of downcast */ std::vector DowncastTilingDecision(const ObjectRef& decision) { - const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode); + const auto* arr = TVM_TYPE_AS(decision, runtime::ArrayNode); return support::AsVector(GetRef>(arr)); } @@ -123,7 +123,7 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].get())) { - const auto* d = TVM_TYPE_AS(d, decision, IntImmNode); + const auto* d = TVM_TYPE_AS(decision, IntImmNode); instructions.push_back(inst); decisions.push_back(d->value); } diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 94e83488584e..c282a171c3b7 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -91,7 +91,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, for (const Instruction& inst : trace->insts) { if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); - const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode); sample_insts[var_rv] = inst.get(); } else if (IsAnnotateWithUnroll(inst)) { ann_insts.push_back(inst.get()); @@ -103,7 +103,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, } const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; ICHECK_EQ(ann_inst->inputs.size(), 2); - const auto* var_rv = 
TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode); + const auto* var_rv = TVM_TYPE_AS(ann_inst->inputs[1], PrimExprNode); ICHECK(sample_insts.count(var_rv)); const InstructionNode* sample_inst = sample_insts.at(var_rv); ICHECK_EQ(sample_inst->attrs.size(), 2); diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index f3c2b1328bc3..08d25d017840 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -233,7 +233,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, int64_t prod_extent = 1; for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { const StmtSRef& loop_sref = loop_srefs[i]; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (HasAnnOrBinding(loop)) { break; } @@ -262,7 +262,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, for (int i = n_loops - 1; i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) { const StmtSRef& loop_sref = loop_srefs[i]; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (HasAnnOrBinding(loop)) { break; } diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index ff4d26084e57..d8f52fa8e1de 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -45,7 +45,7 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv, int i_spatial_loop = -1; for (int i = 0; i < n; ++i) { const StmtSRef& loop_sref = loops[i]; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); runtime::ThreadScope thread_scope = GetThreadScope(loop); if (IsBlockIdx(thread_scope)) { if (i_block_idx == -1) { diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index df4d3ac85911..76313f46d1c8 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -96,7 +96,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, StmtSRef block_sref = sch->GetSRef(block_rv); bool is_pure_sptial = IsInSpatialPrimFunc(sch, block_sref); ScheduleState state = sch->state(); - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); BlockRealize realize = GetBlockRealize(state, block_sref); // Cond 1. 
The block has only one write buffer if (block->writes.size() != 1) { diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index eefc2eea411b..c126c854462c 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -37,7 +37,7 @@ namespace tir { * of multi-level tiling, so it's intentionally kept inside this file not in the analysis header */ std::vector GetReadBufferNDims(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); const BufferNode* write_buffer = block->writes[0]->buffer.get(); int n = block->reads.size(); std::vector results(n, -1); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 49704fb66b15..7ddda9b2635b 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -411,7 +411,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); // Add reindex stages - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); // Hold the reference of the block before reindex const tir::Block block_before_reindex = GetRef(block); if (block->reads.size() != 2 || block->writes.size() != 1) { @@ -488,7 +488,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( } visited_buffers.insert(lhs_buffer); // Refresh block pointer (block sref is not invalidated) - block = TVM_SREF_TO_BLOCK(block, block_sref); + block = TVM_SREF_TO_BLOCK(block_sref); const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion( state->sch->state(), GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index e4b5d5bde256..65988dfd5688 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -60,7 +60,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { private: bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { tir::StmtSRef block_sref = sch->GetSRef(block_rv); - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + TVM_SREF_TO_BLOCK(block_sref); // Cond 1. The block is not the root block. 
if (block_sref->parent == nullptr) { diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index cb84596eed11..664a6a609e7f 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -238,7 +238,7 @@ inline std::string Concat(const Array& strs, const std::string& delim) { */ inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, const String& global_var_name) { - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); return sch->GetBlock(block->name_hint, global_var_name); } diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 62ec0b468f9d..b9e99257f37c 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -150,7 +150,7 @@ Definition of a scope that is a stage pipeline: if (require_stage_pipeline) { bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; if (stage_pipeline == false) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref); throw NotStagePipelineError(self->mod, GetRef(block)); } } @@ -229,7 +229,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, } } // Check whether the input block is the only writer of its outputs - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); for (const BufferRegion& write_region : block->writes) { if (buffer_writers.count(write_region->buffer)) { if (buffer_writers.at(write_region->buffer).size() != 1) { @@ -252,7 +252,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { // Cond 1. All block vars are data parallel - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != kDataPar) { return 1; @@ -328,7 +328,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw IncompleteBlockError(self->mod, GetRef(block), error_code); } } @@ -344,7 +344,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, */ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); // Cond 1. The block has the `init` statement. 
if (!block->init.defined()) { return 1; @@ -394,7 +394,7 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw NotReductionBlockError(self->mod, GetRef(block), error_code); } } @@ -441,7 +441,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl if (reduction_block_error_code == 0) { return; } - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw NotCompleteOrReductionBlockError(self->mod, GetRef(block), complete_block_error_code, reduction_block_error_code); } @@ -491,7 +491,7 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root), local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root); if (local_complete_block_code != 0 && local_reduction_block_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw NotCompactDataFlowError(self->mod, GetRef(subtree_root->stmt), GetRef(block), local_complete_block_code, local_reduction_block_code); @@ -501,8 +501,8 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); std::unordered_set scope_allocated; scope_allocated.reserve(scope_root->alloc_buffers.size()); for (const Buffer& buffer : scope_root->alloc_buffers) { @@ -532,7 +532,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, Block block_; }; if (IsOutputBlock(self, block_sref, scope_root_sref)) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw OutputBlockError(self->mod, GetRef(block)); } } @@ -547,12 +547,12 @@ std::vector GetBlockVarTypes(const BlockNode* block) { } std::vector GetBlockVarTypes(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); return GetBlockVarTypes(block); } bool IsWriteCache(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); if (block->writes.size() != 1) { return false; } @@ -751,7 +751,7 @@ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sre IRModule mod_; For loop_; }; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (!analyzer->CanProve(loop->min == 0)) { throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); } @@ -856,7 +856,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr const BlockRealizeNode* result; }; - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = 
TVM_SREF_TO_BLOCK(block_sref); if (block_sref->parent == nullptr) { const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr); return Downcast(func->body); @@ -870,7 +870,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } IterVarType GetLoopIterType(const StmtSRef& loop_sref) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const Var& loop_var = loop->loop_var; int n_spatial = 0; int n_reduce = 0; @@ -1924,7 +1924,7 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) { } bool IsSpatial(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != IterVarType::kDataPar) { return false; @@ -1934,14 +1934,14 @@ bool IsSpatial(const StmtSRef& block_sref) { } bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + TVM_SREF_TO_BLOCK(block_sref); Array loops = GetLoops(block_sref); Array binds = GetBlockRealize(self, block_sref)->iter_values; if (loops.size() != binds.size()) { return false; } for (int i = 0, n = loops.size(); i < n; ++i) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); + const ForNode* loop = TVM_SREF_TO_FOR(loops[i]); if (binds[i].get() != loop->loop_var.get()) { return false; } @@ -1953,7 +1953,7 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref if (HasBeenMultiLevelTiled(block_sref)) { return false; } - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || !IsTrivialBinding(self, block_sref)) { return false; @@ -2065,7 +2065,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // const tir::StmtSRef& block_sref, // int64_t max_parallel_extent, // int64_t max_parallel_basic) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Array loops = tir::GetLoops(block_sref); // Cond 1. The block has only one write buffer @@ -2100,9 +2100,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } // Cond 5. - const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]); + const ForNode* loop_i = TVM_SREF_TO_FOR(loops[i]); if (i < loops.size() - 1) { - const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]); + const ForNode* loop_i1 = TVM_SREF_TO_FOR(loops[i + 1]); if (loop_i->body.get() != loop_i1) { return false; } @@ -2194,7 +2194,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); // Step 2. 
Collect loops from block_sref const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); - const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + TVM_SREF_TO_BLOCK(scope_sref); std::vector block_loops; std::unordered_set block_loop_vars; { diff --git a/src/tir/schedule/block_scope.cc b/src/tir/schedule/block_scope.cc index f1ce65e48e03..31452f4a8f15 100644 --- a/src/tir/schedule/block_scope.cc +++ b/src/tir/schedule/block_scope.cc @@ -76,7 +76,7 @@ BlockScope::BlockScope(const Array& child_block_srefs) { SMap> buffer_readers; SMap>& buffer_writers = n->buffer_writers; for (const StmtSRef& child_block_sref : child_block_srefs) { - const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref); + const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref); // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer for (const BufferRegion& region : child_block->reads) { buffer_readers[region->buffer].push_back(child_block_sref); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index c16638f748b4..32fe01a6f55e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -269,7 +269,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional(block)); } } @@ -432,7 +432,7 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, // Prepare for the splitting StmtSRef loop_sref = this->GetSRef(loop_rv); - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); Array factors; factors.reserve(factor_rvs.size()); int infer_index = -1; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index cdd0a5b7b0a2..42da373f1bf2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -205,13 +205,13 @@ class ConcreteScheduleNode : public ScheduleNode { inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { StmtSRef sref = this->GetSRef(block_rv); - const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(sref); return GetRef(block); } inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { StmtSRef sref = this->GetSRef(loop_rv); - const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(sref); return GetRef(loop); } @@ -222,7 +222,7 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; } const ObjectRef& obj = (*it).second; - const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode); + const auto* int_imm = TVM_TYPE_AS(obj, IntImmNode); return Integer(int_imm->value); }); return this->analyzer_->Simplify(transformed); diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 2d876d9bf7fa..31c938313fed 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -238,7 +238,7 @@ class StorageScopeMutator : private ReplaceBufferMutator { void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, BufferIndexType::kWrite); 
StorageAlignInvalidFactorError::Check(self->mod, factor); @@ -274,7 +274,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, const String& storage_scope) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); @@ -289,7 +289,7 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, // Step 3. Get the allocation site of the target buffer. StmtSRef alloc_site_sref = NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); - const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site, alloc_site_sref); + const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref); // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given // storage scope. In the meanwhile, collect the block sref reuse information. diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index cf6532e82d46..7481a7c92494 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -426,7 +426,7 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, Map* block_sref_reuse, arith::Analyzer* analyzer) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + TVM_SREF_TO_FOR(loop_sref); // Step 1: Check and get the only block under `loop`. BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); Block block = block_realize->block; diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 529d3333cd18..a221733eb394 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -31,7 +31,7 @@ class NotSingleWriteBlock : public ScheduleError { ICHECK_GT(write_blocks.size(), 1); write_blocks_.reserve(write_blocks.size()); for (const StmtSRef& block_sref : write_blocks) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); write_blocks_.push_back(GetRef(block)); } } @@ -532,7 +532,7 @@ class CacheReadRewriter : public StmtExprMutator { bool is_consumer = info_->consumer_blocks.empty(); // Otherwise check if this is one of the specified blocks. for (StmtSRef consumer_sref : info_->consumer_blocks) { - const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_node, consumer_sref); + const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); Block consumer_block = GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { is_consumer = true; @@ -999,11 +999,11 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff CheckStorageScope(self, storage_scope); // Step 1. 
Check index, getting the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer read_buffer = GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 2. Create CacheStageInfo CacheStageInfo info; @@ -1020,7 +1020,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { // Case 1. The buffer is written inside the block. StmtSRef write_block_sref = _write_block_sref.value(); - const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block, write_block_sref); + const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); // Find the producing region BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value(); StmtSRef parent_sref = GetRef(write_block_sref->parent); @@ -1072,7 +1072,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu CheckStorageScope(self, storage_scope); // Step 1. Checking index, getting the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer write_buffer = GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kWrite); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); @@ -1114,7 +1114,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Block block = GetRef(block_ptr); Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 98a6b2400ee3..0acbdae6924c 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -37,7 +37,7 @@ class NotAllRequiredBlocksAreVisitedError : public ScheduleError { : mod_(mod), num_not_visited_(num_not_visited) { required_.reserve(required.size()); for (const StmtSRef& block_sref : required) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); required_.push_back(GetRef(block)); } } @@ -288,14 +288,14 @@ class ScopeReconstructor : private StmtMutator { return GetRef(block); } if (block == rm_src_stmt_.get()) { - block = TVM_TYPE_AS(block, rm_tgt_stmt_, BlockNode); + block = TVM_TYPE_AS(rm_tgt_stmt_, BlockNode); } return StmtMutator::VisitStmt_(block); } Stmt VisitStmt_(const ForNode* loop) final { if (loop == rm_src_stmt_.get()) { - loop = TVM_TYPE_AS(loop, rm_tgt_stmt_, ForNode); + loop = TVM_TYPE_AS(rm_tgt_stmt_, ForNode); } if (loop == loop_.get()) { return new_loop_; @@ -541,7 +541,7 @@ void CalculateProvidedRequiredRegions( } // Step 2. 
Calculate the region required by dependent blocks under `loop` for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) { - const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block, required_block_sref); + const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block_sref); ICHECK(block2realize.count(required_block)); RelaxBufferRegions( /*binding=*/GetBindings(GetRef(block2realize.at(required_block))), @@ -557,8 +557,8 @@ template void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops, arith::Analyzer* analyzer, bool check_only = false) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); // Step 1. Bunch of checks // Check condition 1) : scope stage pipeline StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index ad15e06e285a..bfda66036fe3 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -174,7 +174,7 @@ class NonSingleProducerError : public ScheduleError { } } } - const BlockNode* block = TVM_SREF_TO_BLOCK(block, consumer_block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(consumer_block_sref); throw NonSingleProducerError(self->mod, GetRef(block)); } }; @@ -183,7 +183,7 @@ class OpaqueAccessError : public ScheduleError { public: explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref) : mod_(mod), scope_root_(nullptr) { - const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); this->scope_root_ = GetRef(scope_root); } @@ -653,7 +653,7 @@ class ReverseComputeInliner : public BaseInliner { void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, bool check_only = false) { - const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); + const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(producer_block_sref); Block producer_block = GetRef(_producer_block); HasInitBlock::Check(self->mod, producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); @@ -698,7 +698,7 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref, bool check_only = false) { - const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); + const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); Block consumer_block = GetRef(_consumer_block); HasInitBlock::Check(self->mod, consumer_block); // Step 1. 
Get the scope block diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 365c6d43f127..93fb88e66619 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -415,7 +415,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, * - trim original block to write non-padding part only */ // Condition Checks and Information Collection - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); Map dom_map; arith::Analyzer analyzer; diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index ec337224e59d..cc8cb55fd3fa 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -145,7 +145,7 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind */ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind, Optional thread_axis) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); /* * Check: @@ -186,7 +186,7 @@ void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_a } void Unroll(ScheduleState self, const StmtSRef& loop_sref) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); ObjectPtr new_loop = make_object(*loop); new_loop->kind = ForKind::kUnrolled; new_loop->thread_binding = NullOpt; diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 746918ac4e34..cbdb99c6444f 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -40,7 +40,7 @@ Array GetBlocks(const ScheduleState& self, const String& name, const G }; BaseFunc func = self->mod->Lookup(gv); - const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); + const auto* prim_func = TVM_TYPE_AS(func, PrimFuncNode); Finder finder(self, name); finder(prim_func->body); return std::move(finder.results_); diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 148b3ee033c3..b4e40fa120fe 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -134,7 +134,7 @@ class BufferIsSubregionError : public ScheduleError { void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); Optional defining_site_sref; @@ -147,7 +147,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ StmtSRef scope_sref = defining_site_sref.defined() ? 
defining_site_sref.value() : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 1: Infer the shape of the new buffer ObjectPtr new_buffer_node = make_object(*(old_buffer.get())); @@ -344,7 +344,7 @@ class OpaqueNewIterTypeError : public ScheduleError { void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, const IndexMap& index_map) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); const Block& block = GetRef(block_ptr); arith::Analyzer analyzer; @@ -489,7 +489,7 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const Array& axis_separators) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); Optional defining_site_sref; @@ -502,7 +502,7 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer StmtSRef scope_sref = defining_site_sref.defined() ? defining_site_sref.value() : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 1: Check and update axis_separators of the buffer. Buffer new_buffer = old_buffer; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index f1b6f46e1b8f..2db3eb902aba 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -87,7 +87,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { bool preserve_unit_iters) { Map loop_var2extent; for (const StmtSRef& sref : loop_srefs) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(sref); loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } return Downcast(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent), @@ -389,7 +389,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array // - The execution order has not changed. (The block executes with the same args and the same // order with before. // Step 1. Check correctness - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (!loop->annotations.empty() || loop->thread_binding.defined()) { throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); } @@ -445,7 +445,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array result_srefs.reserve(n); for (int i = 0; i < n; i++) { result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); - const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode); + const ForNode* outer_loop = TVM_TYPE_AS(new_stmt, ForNode); new_stmt = outer_loop->body; } return result_srefs; @@ -464,7 +464,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preser std::unordered_set outer_loop_vars; // Step 1. 
check correctness for (const StmtSRef& sref : loop_srefs) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(sref); if (!loop->annotations.empty() || loop->thread_binding.defined()) { throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); } @@ -554,7 +554,7 @@ std::unordered_set CollectLoopsIntoSet( for (const StmtSRef& loop_sref : ordered_loop_srefs) { auto inserted = loop_srefs.insert(loop_sref.get()); if (!inserted.second) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); throw LoopMultiAppearanceError(self->mod, GetRef(loop)); } } diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index ad9043e4f2db..7a4ace736e48 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -123,7 +123,7 @@ class LoopHeightError : public ScheduleError { // loop_var of a higher loop shouldn't contain loop var const Var& loop_var = higher_loop->StmtAs()->loop_var; if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); throw LoopHeightError(mod, GetRef(loop), GetRef(block)); } } @@ -183,8 +183,8 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, * - generate corresponding init block and update block */ // Condition Checks and Information Collection - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); // Get the outer loops from high to low Array loops = GetLoops(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); @@ -264,7 +264,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, std::unordered_map loop_var_map; Stmt body = BlockRealize(init_realize); for (int i : chosen_loops) { - const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]); + const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]); // Create a new equivalent to the chosen loop Var old_loop_var = old_loop->loop_var; Var new_loop_var = old_loop_var.copy_with_suffix("_init"); @@ -277,7 +277,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, } body = Substitute(body, loop_var_map); // Step 6. 
Mutate IR - const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref); + const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); Block new_scope_root{nullptr}; Block new_reduction_block{nullptr}; std::tie(new_scope_root, new_reduction_block) = DecomposeReductionBlockReplacer::Replace( @@ -1013,7 +1013,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax StmtSRef scope_root = GetScopeRoot(self, block_sref, // /*require_stage_pipeline=*/true); CheckReductionBlock(self, block_sref, scope_root); - const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref); + const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref); if (rf_loop->kind != ForKind::kSerial) { throw NotSerialLoopKindError(self->mod, GetRef(rf_loop)); } diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 1961565aac75..52b5add2bc9e 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -311,7 +311,7 @@ std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional>* decision) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const int64_t* extent = GetLoopIntExtent(loop); std::vector result; if (extent == nullptr) { diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 07481ddb19e3..15d0e08ddc2c 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -208,7 +208,7 @@ class BlockInfoCollector : private StmtVisitor { if (is_root_block) { // If the block doesn't have outer loops and BlockRealize, // then we set the affine binding flag as true only if the block has no block vars - const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root); + const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); if (block->iter_vars.empty()) info.affine_binding = true; } else { info.affine_binding = @@ -233,7 +233,7 @@ class BlockInfoCollector : private StmtVisitor { block_reads_unbound.reserve(child_block_srefs.size()); block_writes_unbound.reserve(child_block_srefs.size()); for (const StmtSRef& block_sref : child_block_srefs) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Map binding = GetBindings(block2realize_.at(block)); // Step 1.1. Unbind read regions Array reads; @@ -254,7 +254,7 @@ class BlockInfoCollector : private StmtVisitor { for (const auto& kv : info.scope->dst2deps) { const StmtSRef& consumer_block_sref = kv.first; const Array& deps = kv.second; - const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref); + const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); const BlockRealize& consumer_realize = block2realize_.at(consumer_block); bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true; // Step 2.1. Extract the path to the scope root @@ -851,7 +851,7 @@ class ChildReplacer : private StmtMutator { } else if (const auto* realize = stmt.as()) { // Case 2. 
stmt is BlockRealize, src_stmt is Block if (realize->block.get() == src_stmt) { - const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode); + const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, BlockNode); ObjectPtr new_realize = make_object(*realize); new_realize->block = GetRef(tgt_block); new_stmt = BlockRealize(std::move(new_realize)); @@ -1044,9 +1044,9 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ // If `g_func` was unique, after the 3 lines above: // `ref_new_func` points to the same unique function that `g_func` points to // Update the body of the function the sref belongs to Assign - const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode); + const auto* realize = TVM_TYPE_AS(g_func->body, BlockRealizeNode); // Make `child_tgt_stmt` the root block - const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode); + const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, BlockNode); ObjectPtr new_realize = make_object(*realize); new_realize->block = GetRef(child_block); new_func->body = BlockRealize(std::move(new_realize)); @@ -1078,7 +1078,7 @@ void ScheduleStateNode::DebugVerify() const { /**************** BlockInfo-related ****************/ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + TVM_SREF_TO_BLOCK(block_sref); auto it = this->block_info.find(block_sref); CHECK(it != this->block_info.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 1c21d770db30..1ebaf202d487 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -36,7 +36,7 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec Buffer WithScope(const Buffer& buffer, const String& scope) { ObjectPtr new_buffer = make_object(*buffer.get()); ObjectPtr new_var = make_object(*buffer->data.get()); - const auto* ptr_type = TVM_TYPE_AS(ptr_type, buffer->data->type_annotation, PointerTypeNode); + const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode); new_var->type_annotation = PointerType(ptr_type->element_type, scope); new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation); new_buffer->name = buffer->name + "_" + scope; @@ -253,8 +253,8 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } } ICHECK(sref != nullptr && sref->stmt != nullptr); - const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block, leaf_block_sref); - const auto* scope_block = TVM_SREF_TO_BLOCK(scope_block, sref); + const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block_sref); + const auto* scope_block = TVM_SREF_TO_BLOCK(sref); throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 3db80989ae10..c289309acc2d 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -62,25 +62,35 @@ namespace tir { /*! * \brief A helper macro to convert an sref to the block it points to, - * throwing an internal error if downcasting fails - * \param Result The result variable, used for checking + * + * Throws an internal error if downcasting fails. The variable name + * in the parent scope is used for the error message. 
+ * * \param SRef The SRef to be cast */ -#define TVM_SREF_TO_BLOCK(Result, SRef) \ - TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \ - << "TypeError: Expects StmtSRef `" << #SRef \ - << "` points to `Block`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None") +#define TVM_SREF_TO_BLOCK(SRef) \ + [&]() { \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode) \ + << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \ + << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ + return result; \ + }() /*! - * \brief A helper macro to convert an sref to the for-loop it points to, - * throwing an internal error if downcasting fails - * \param Result The name of the result variable, used for checking + * \brief A helper macro to convert an sref to the for-loop it points to + * + * Throws an internal error if downcasting fails. The variable name + * in the parent scope is used for the error message. + * * \param SRef The SRef to be cast */ -#define TVM_SREF_TO_FOR(Result, SRef) \ - TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \ - << "TypeError: Expects StmtSRef `" << #SRef \ - << "` points to `Loop`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None") +#define TVM_SREF_TO_FOR(SRef) \ + [&]() { \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode) \ + << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \ + << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ + return result; \ + }() /*! * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, @@ -100,10 +110,13 @@ namespace tir { * \param From The ObjectRef to be downcast * \param Type The type to be downcast to */ -#define TVM_TYPE_AS(Result, From, Type) \ - TVM_TYPE_AS_OR_ERR(Result, From, Type) \ - << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ - << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None") +#define TVM_TYPE_AS(From, Type) \ + [&]() { \ + auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type) \ + << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ + << "`, but gets: " << ((From).defined() ? (From)->GetTypeKey() : "None"); \ + return result; \ + }() /*! * \brief Convert an array of loop StmtSRefs to an array of loops @@ -114,7 +127,7 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { Array loops; loops.reserve(loop_srefs.size()); for (StmtSRef loop_sref : loop_srefs) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); loops.push_back(GetRef(loop)); } return loops; @@ -264,7 +277,7 @@ inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_in * \return The extent of the loop, nullptr if the extent is not constant */ inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); return as_const_int(loop->extent); }
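Usage note (illustrative sketch, not part of the diff above): the updated macros expand to an immediately-invoked lambda, so each use gets its own scoped `result` variable and the whole macro can sit inside a larger expression. The snippet below is a minimal, self-contained C++ sketch of that pattern; `SREF_TO_BLOCK`, `StmtAs`, and the node types are simplified stand-ins for TVM's real definitions, and `assert` stands in for the `ICHECK` error stream.

#include <cassert>
#include <cstdio>

struct StmtNode {
  virtual ~StmtNode() = default;
};
struct BlockNode : StmtNode {
  const char* name_hint = "root";
};
struct StmtSRef {
  const StmtNode* stmt = nullptr;
};

// Stand-in for StmtSRef::StmtAs<Type>(): returns nullptr if the cast fails.
template <typename Type>
const Type* StmtAs(const StmtSRef& sref) {
  return dynamic_cast<const Type*>(sref.stmt);
}

// The macro is an expression: the lambda gives `result` its own scope, and
// the parentheses around (SRef) keep compound input arguments intact.
#define SREF_TO_BLOCK(SRef)                                               \
  [&]() {                                                                 \
    auto result = StmtAs<BlockNode>((SRef));                              \
    assert(result && "StmtSRef `" #SRef "` does not point to a `Block`"); \
    return result;                                                        \
  }()

int main() {
  BlockNode block;
  StmtSRef sref{&block};
  // No output-variable argument, and usable directly inside an expression --
  // the old form would have been SREF_TO_BLOCK(block_ptr, sref).
  std::printf("%s\n", SREF_TO_BLOCK(sref)->name_hint);
  return 0;
}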