diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 5b7fe7f5148d..82b91da682c6 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -64,7 +64,7 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { return nullptr; } ICHECK_EQ(inst->outputs.size(), 1); - const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode); + const BlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], BlockRVNode); return block; } @@ -82,7 +82,7 @@ std::vector> AnalyzeParallel(const ScheduleState& self, Array block_srefs = tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name)); ICHECK_EQ(block_srefs.size(), 1); - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]); ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); std::vector> results; results.reserve(info.realizes.size()); diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 41207162ee1d..de780b53e2d9 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -109,12 +109,12 @@ std::vector MutateThreadBindingNode::FindCan for (const Instruction& inst : trace->insts) { if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); - const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode); sample_insts[var_rv] = inst.get(); } else if (is_split_by_sample(inst)) { CHECK_EQ(inst->outputs.size(), 2); // Only consider the inner loop, which can be bound to threadIdx.x - const tir::LoopRVNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[1], tir::LoopRVNode); + const tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], tir::LoopRVNode); sampled_split_insts[var_rv] = inst.get(); } else if (is_thread_binding_by_sample(inst)) { bind_insts.push_back(inst.get()); @@ -122,12 +122,12 @@ std::vector MutateThreadBindingNode::FindCan } for (const InstructionNode* bind_inst : bind_insts) { - const auto* loop_rv = TVM_TYPE_AS(loop_rv, bind_inst->inputs[0], tir::LoopRVNode); + const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], tir::LoopRVNode); auto split_it = sampled_split_insts.find(loop_rv); ICHECK(split_it != sampled_split_insts.end()); const InstructionNode* split_inst = split_it->second; - const auto* expr_rv = TVM_TYPE_AS(expr_rv, split_inst->inputs[2], PrimExprNode); + const auto* expr_rv = TVM_TYPE_AS(split_inst->inputs[2], PrimExprNode); auto sample_it = sample_insts.find(expr_rv); ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 00967aef7acd..4a3bfda8a4a8 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -34,7 +34,7 @@ using tir::Trace; * \return The result of downcast */ std::vector DowncastTilingDecision(const ObjectRef& decision) { - const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode); + const auto* arr = TVM_TYPE_AS(decision, runtime::ArrayNode); return support::AsVector(GetRef>(arr)); } @@ -123,7 +123,7 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].get())) { - const auto* d = TVM_TYPE_AS(d, decision, IntImmNode); + const auto* d = TVM_TYPE_AS(decision, IntImmNode); instructions.push_back(inst); decisions.push_back(d->value); } diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 94e83488584e..c282a171c3b7 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -91,7 +91,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, for (const Instruction& inst : trace->insts) { if (inst->kind.same_as(inst_sample_categorical)) { ICHECK_EQ(inst->outputs.size(), 1); - const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode); sample_insts[var_rv] = inst.get(); } else if (IsAnnotateWithUnroll(inst)) { ann_insts.push_back(inst.get()); @@ -103,7 +103,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, } const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; ICHECK_EQ(ann_inst->inputs.size(), 2); - const auto* var_rv = TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode); + const auto* var_rv = TVM_TYPE_AS(ann_inst->inputs[1], PrimExprNode); ICHECK(sample_insts.count(var_rv)); const InstructionNode* sample_inst = sample_insts.at(var_rv); ICHECK_EQ(sample_inst->attrs.size(), 2); diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index f3c2b1328bc3..08d25d017840 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -233,7 +233,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, int64_t prod_extent = 1; for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { const StmtSRef& loop_sref = loop_srefs[i]; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (HasAnnOrBinding(loop)) { break; } @@ -262,7 +262,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, for (int i = n_loops - 1; i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) { const StmtSRef& loop_sref = loop_srefs[i]; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (HasAnnOrBinding(loop)) { break; } diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index ff4d26084e57..d8f52fa8e1de 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -45,7 +45,7 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv, int i_spatial_loop = -1; for (int i = 0; i < n; ++i) { const StmtSRef& loop_sref = loops[i]; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); runtime::ThreadScope thread_scope = GetThreadScope(loop); if (IsBlockIdx(thread_scope)) { if (i_block_idx == -1) { diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index df4d3ac85911..76313f46d1c8 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -96,7 +96,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, StmtSRef block_sref = sch->GetSRef(block_rv); bool is_pure_sptial = IsInSpatialPrimFunc(sch, block_sref); ScheduleState state = sch->state(); - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); BlockRealize realize = GetBlockRealize(state, block_sref); // Cond 1. The block has only one write buffer if (block->writes.size() != 1) { diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index eefc2eea411b..c126c854462c 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -37,7 +37,7 @@ namespace tir { * of multi-level tiling, so it's intentionally kept inside this file not in the analysis header */ std::vector GetReadBufferNDims(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); const BufferNode* write_buffer = block->writes[0]->buffer.get(); int n = block->reads.size(); std::vector results(n, -1); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 49704fb66b15..7ddda9b2635b 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -411,7 +411,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); // Add reindex stages - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); // Hold the reference of the block before reindex const tir::Block block_before_reindex = GetRef(block); if (block->reads.size() != 2 || block->writes.size() != 1) { @@ -488,7 +488,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( } visited_buffers.insert(lhs_buffer); // Refresh block pointer (block sref is not invalidated) - block = TVM_SREF_TO_BLOCK(block, block_sref); + block = TVM_SREF_TO_BLOCK(block_sref); const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion( state->sch->state(), GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index e4b5d5bde256..65988dfd5688 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -60,7 +60,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { private: bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { tir::StmtSRef block_sref = sch->GetSRef(block_rv); - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + TVM_SREF_TO_BLOCK(block_sref); // Cond 1. The block is not the root block. if (block_sref->parent == nullptr) { diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index cb84596eed11..664a6a609e7f 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -238,7 +238,7 @@ inline std::string Concat(const Array& strs, const std::string& delim) { */ inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, const String& global_var_name) { - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); return sch->GetBlock(block->name_hint, global_var_name); } diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 62ec0b468f9d..b9e99257f37c 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -150,7 +150,7 @@ Definition of a scope that is a stage pipeline: if (require_stage_pipeline) { bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; if (stage_pipeline == false) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref); throw NotStagePipelineError(self->mod, GetRef(block)); } } @@ -229,7 +229,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, } } // Check whether the input block is the only writer of its outputs - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); for (const BufferRegion& write_region : block->writes) { if (buffer_writers.count(write_region->buffer)) { if (buffer_writers.at(write_region->buffer).size() != 1) { @@ -252,7 +252,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { // Cond 1. All block vars are data parallel - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != kDataPar) { return 1; @@ -328,7 +328,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw IncompleteBlockError(self->mod, GetRef(block), error_code); } } @@ -344,7 +344,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, */ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); // Cond 1. The block has the `init` statement. if (!block->init.defined()) { return 1; @@ -394,7 +394,7 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw NotReductionBlockError(self->mod, GetRef(block), error_code); } } @@ -441,7 +441,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl if (reduction_block_error_code == 0) { return; } - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw NotCompleteOrReductionBlockError(self->mod, GetRef(block), complete_block_error_code, reduction_block_error_code); } @@ -491,7 +491,7 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root), local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root); if (local_complete_block_code != 0 && local_reduction_block_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw NotCompactDataFlowError(self->mod, GetRef(subtree_root->stmt), GetRef(block), local_complete_block_code, local_reduction_block_code); @@ -501,8 +501,8 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); std::unordered_set scope_allocated; scope_allocated.reserve(scope_root->alloc_buffers.size()); for (const Buffer& buffer : scope_root->alloc_buffers) { @@ -532,7 +532,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, Block block_; }; if (IsOutputBlock(self, block_sref, scope_root_sref)) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); throw OutputBlockError(self->mod, GetRef(block)); } } @@ -547,12 +547,12 @@ std::vector GetBlockVarTypes(const BlockNode* block) { } std::vector GetBlockVarTypes(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); return GetBlockVarTypes(block); } bool IsWriteCache(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); if (block->writes.size() != 1) { return false; } @@ -751,7 +751,7 @@ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sre IRModule mod_; For loop_; }; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (!analyzer->CanProve(loop->min == 0)) { throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); } @@ -856,7 +856,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr const BlockRealizeNode* result; }; - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); if (block_sref->parent == nullptr) { const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr); return Downcast(func->body); @@ -870,7 +870,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } IterVarType GetLoopIterType(const StmtSRef& loop_sref) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const Var& loop_var = loop->loop_var; int n_spatial = 0; int n_reduce = 0; @@ -1924,7 +1924,7 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) { } bool IsSpatial(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != IterVarType::kDataPar) { return false; @@ -1934,14 +1934,14 @@ bool IsSpatial(const StmtSRef& block_sref) { } bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + TVM_SREF_TO_BLOCK(block_sref); Array loops = GetLoops(block_sref); Array binds = GetBlockRealize(self, block_sref)->iter_values; if (loops.size() != binds.size()) { return false; } for (int i = 0, n = loops.size(); i < n; ++i) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); + const ForNode* loop = TVM_SREF_TO_FOR(loops[i]); if (binds[i].get() != loop->loop_var.get()) { return false; } @@ -1953,7 +1953,7 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref if (HasBeenMultiLevelTiled(block_sref)) { return false; } - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || !IsTrivialBinding(self, block_sref)) { return false; @@ -2065,7 +2065,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // const tir::StmtSRef& block_sref, // int64_t max_parallel_extent, // int64_t max_parallel_basic) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Array loops = tir::GetLoops(block_sref); // Cond 1. The block has only one write buffer @@ -2100,9 +2100,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } // Cond 5. - const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]); + const ForNode* loop_i = TVM_SREF_TO_FOR(loops[i]); if (i < loops.size() - 1) { - const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]); + const ForNode* loop_i1 = TVM_SREF_TO_FOR(loops[i + 1]); if (loop_i->body.get() != loop_i1) { return false; } @@ -2194,7 +2194,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); // Step 2. Collect loops from block_sref const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); - const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + TVM_SREF_TO_BLOCK(scope_sref); std::vector block_loops; std::unordered_set block_loop_vars; { diff --git a/src/tir/schedule/block_scope.cc b/src/tir/schedule/block_scope.cc index f1ce65e48e03..31452f4a8f15 100644 --- a/src/tir/schedule/block_scope.cc +++ b/src/tir/schedule/block_scope.cc @@ -76,7 +76,7 @@ BlockScope::BlockScope(const Array& child_block_srefs) { SMap> buffer_readers; SMap>& buffer_writers = n->buffer_writers; for (const StmtSRef& child_block_sref : child_block_srefs) { - const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref); + const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref); // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer for (const BufferRegion& region : child_block->reads) { buffer_readers[region->buffer].push_back(child_block_sref); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 5f773a02d6ff..afc675799706 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -269,7 +269,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional(block)); } } @@ -432,7 +432,7 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, // Prepare for the splitting StmtSRef loop_sref = this->GetSRef(loop_rv); - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); Array factors; factors.reserve(factor_rvs.size()); int infer_index = -1; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 92b9de408873..e79d1d528809 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -206,13 +206,13 @@ class ConcreteScheduleNode : public ScheduleNode { inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { StmtSRef sref = this->GetSRef(block_rv); - const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(sref); return GetRef(block); } inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { StmtSRef sref = this->GetSRef(loop_rv); - const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(sref); return GetRef(loop); } @@ -223,7 +223,7 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; } const ObjectRef& obj = (*it).second; - const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode); + const auto* int_imm = TVM_TYPE_AS(obj, IntImmNode); return Integer(int_imm->value); }); return this->analyzer_->Simplify(transformed); diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 2d876d9bf7fa..31c938313fed 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -238,7 +238,7 @@ class StorageScopeMutator : private ReplaceBufferMutator { void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, BufferIndexType::kWrite); StorageAlignInvalidFactorError::Check(self->mod, factor); @@ -274,7 +274,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, const String& storage_scope) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); @@ -289,7 +289,7 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, // Step 3. Get the allocation site of the target buffer. StmtSRef alloc_site_sref = NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); - const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site, alloc_site_sref); + const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref); // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given // storage scope. In the meanwhile, collect the block sref reuse information. diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index cf6532e82d46..7481a7c92494 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -426,7 +426,7 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, Map* block_sref_reuse, arith::Analyzer* analyzer) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + TVM_SREF_TO_FOR(loop_sref); // Step 1: Check and get the only block under `loop`. BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); Block block = block_realize->block; diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 529d3333cd18..a221733eb394 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -31,7 +31,7 @@ class NotSingleWriteBlock : public ScheduleError { ICHECK_GT(write_blocks.size(), 1); write_blocks_.reserve(write_blocks.size()); for (const StmtSRef& block_sref : write_blocks) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); write_blocks_.push_back(GetRef(block)); } } @@ -532,7 +532,7 @@ class CacheReadRewriter : public StmtExprMutator { bool is_consumer = info_->consumer_blocks.empty(); // Otherwise check if this is one of the specified blocks. for (StmtSRef consumer_sref : info_->consumer_blocks) { - const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_node, consumer_sref); + const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); Block consumer_block = GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { is_consumer = true; @@ -999,11 +999,11 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff CheckStorageScope(self, storage_scope); // Step 1. Check index, getting the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer read_buffer = GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 2. Create CacheStageInfo CacheStageInfo info; @@ -1020,7 +1020,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { // Case 1. The buffer is written inside the block. StmtSRef write_block_sref = _write_block_sref.value(); - const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block, write_block_sref); + const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); // Find the producing region BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value(); StmtSRef parent_sref = GetRef(write_block_sref->parent); @@ -1072,7 +1072,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu CheckStorageScope(self, storage_scope); // Step 1. Checking index, getting the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer write_buffer = GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kWrite); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); @@ -1114,7 +1114,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Block block = GetRef(block_ptr); Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 8baedfd70dd0..83342e351b91 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -37,7 +37,7 @@ class NotAllRequiredBlocksAreVisitedError : public ScheduleError { : mod_(mod), num_not_visited_(num_not_visited) { required_.reserve(required.size()); for (const StmtSRef& block_sref : required) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); required_.push_back(GetRef(block)); } } @@ -306,14 +306,14 @@ class ScopeReconstructor : private StmtMutator { return GetRef(block); } if (block == rm_src_stmt_.get()) { - block = TVM_TYPE_AS(block, rm_tgt_stmt_, BlockNode); + block = TVM_TYPE_AS(rm_tgt_stmt_, BlockNode); } return StmtMutator::VisitStmt_(block); } Stmt VisitStmt_(const ForNode* loop) final { if (loop == rm_src_stmt_.get()) { - loop = TVM_TYPE_AS(loop, rm_tgt_stmt_, ForNode); + loop = TVM_TYPE_AS(rm_tgt_stmt_, ForNode); } if (loop == loop_.get()) { return new_loop_; @@ -559,7 +559,7 @@ void CalculateProvidedRequiredRegions( } // Step 2. Calculate the region required by dependent blocks under `loop` for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) { - const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block, required_block_sref); + const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block_sref); ICHECK(block2realize.count(required_block)); RelaxBufferRegions( /*binding=*/GetBindings(GetRef(block2realize.at(required_block))), @@ -576,8 +576,8 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s const StmtSRef& loop_sref, bool preserve_unit_loops, arith::Analyzer* analyzer, bool check_only = false, int index = -1) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); // Step 1. Bunch of checks // Check condition 1) : scope stage pipeline StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index ad15e06e285a..bfda66036fe3 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -174,7 +174,7 @@ class NonSingleProducerError : public ScheduleError { } } } - const BlockNode* block = TVM_SREF_TO_BLOCK(block, consumer_block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(consumer_block_sref); throw NonSingleProducerError(self->mod, GetRef(block)); } }; @@ -183,7 +183,7 @@ class OpaqueAccessError : public ScheduleError { public: explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref) : mod_(mod), scope_root_(nullptr) { - const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); this->scope_root_ = GetRef(scope_root); } @@ -653,7 +653,7 @@ class ReverseComputeInliner : public BaseInliner { void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, bool check_only = false) { - const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); + const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(producer_block_sref); Block producer_block = GetRef(_producer_block); HasInitBlock::Check(self->mod, producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); @@ -698,7 +698,7 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref, bool check_only = false) { - const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); + const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); Block consumer_block = GetRef(_consumer_block); HasInitBlock::Check(self->mod, consumer_block); // Step 1. Get the scope block diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 365c6d43f127..93fb88e66619 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -415,7 +415,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, * - trim original block to write non-padding part only */ // Condition Checks and Information Collection - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); Map dom_map; arith::Analyzer analyzer; diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index ec337224e59d..cc8cb55fd3fa 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -145,7 +145,7 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind */ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind, Optional thread_axis) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); /* * Check: @@ -186,7 +186,7 @@ void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_a } void Unroll(ScheduleState self, const StmtSRef& loop_sref) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); ObjectPtr new_loop = make_object(*loop); new_loop->kind = ForKind::kUnrolled; new_loop->thread_binding = NullOpt; diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 746918ac4e34..cbdb99c6444f 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -40,7 +40,7 @@ Array GetBlocks(const ScheduleState& self, const String& name, const G }; BaseFunc func = self->mod->Lookup(gv); - const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); + const auto* prim_func = TVM_TYPE_AS(func, PrimFuncNode); Finder finder(self, name); finder(prim_func->body); return std::move(finder.results_); diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 148b3ee033c3..b4e40fa120fe 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -134,7 +134,7 @@ class BufferIsSubregionError : public ScheduleError { void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); Optional defining_site_sref; @@ -147,7 +147,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ StmtSRef scope_sref = defining_site_sref.defined() ? defining_site_sref.value() : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 1: Infer the shape of the new buffer ObjectPtr new_buffer_node = make_object(*(old_buffer.get())); @@ -344,7 +344,7 @@ class OpaqueNewIterTypeError : public ScheduleError { void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, const IndexMap& index_map) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); const Block& block = GetRef(block_ptr); arith::Analyzer analyzer; @@ -489,7 +489,7 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const Array& axis_separators) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); Optional defining_site_sref; @@ -502,7 +502,7 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer StmtSRef scope_sref = defining_site_sref.defined() ? defining_site_sref.value() : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 1: Check and update axis_separators of the buffer. Buffer new_buffer = old_buffer; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index f1b6f46e1b8f..2db3eb902aba 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -87,7 +87,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { bool preserve_unit_iters) { Map loop_var2extent; for (const StmtSRef& sref : loop_srefs) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(sref); loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } return Downcast(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent), @@ -389,7 +389,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array // - The execution order has not changed. (The block executes with the same args and the same // order with before. // Step 1. Check correctness - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (!loop->annotations.empty() || loop->thread_binding.defined()) { throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); } @@ -445,7 +445,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array result_srefs.reserve(n); for (int i = 0; i < n; i++) { result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); - const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode); + const ForNode* outer_loop = TVM_TYPE_AS(new_stmt, ForNode); new_stmt = outer_loop->body; } return result_srefs; @@ -464,7 +464,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preser std::unordered_set outer_loop_vars; // Step 1. check correctness for (const StmtSRef& sref : loop_srefs) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(sref); if (!loop->annotations.empty() || loop->thread_binding.defined()) { throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); } @@ -554,7 +554,7 @@ std::unordered_set CollectLoopsIntoSet( for (const StmtSRef& loop_sref : ordered_loop_srefs) { auto inserted = loop_srefs.insert(loop_sref.get()); if (!inserted.second) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); throw LoopMultiAppearanceError(self->mod, GetRef(loop)); } } diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index ad9043e4f2db..7a4ace736e48 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -123,7 +123,7 @@ class LoopHeightError : public ScheduleError { // loop_var of a higher loop shouldn't contain loop var const Var& loop_var = higher_loop->StmtAs()->loop_var; if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); throw LoopHeightError(mod, GetRef(loop), GetRef(block)); } } @@ -183,8 +183,8 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, * - generate corresponding init block and update block */ // Condition Checks and Information Collection - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); // Get the outer loops from high to low Array loops = GetLoops(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); @@ -264,7 +264,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, std::unordered_map loop_var_map; Stmt body = BlockRealize(init_realize); for (int i : chosen_loops) { - const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]); + const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]); // Create a new equivalent to the chosen loop Var old_loop_var = old_loop->loop_var; Var new_loop_var = old_loop_var.copy_with_suffix("_init"); @@ -277,7 +277,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, } body = Substitute(body, loop_var_map); // Step 6. Mutate IR - const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref); + const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); Block new_scope_root{nullptr}; Block new_reduction_block{nullptr}; std::tie(new_scope_root, new_reduction_block) = DecomposeReductionBlockReplacer::Replace( @@ -1013,7 +1013,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax StmtSRef scope_root = GetScopeRoot(self, block_sref, // /*require_stage_pipeline=*/true); CheckReductionBlock(self, block_sref, scope_root); - const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref); + const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref); if (rf_loop->kind != ForKind::kSerial) { throw NotSerialLoopKindError(self->mod, GetRef(rf_loop)); } diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 1961565aac75..52b5add2bc9e 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -311,7 +311,7 @@ std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional>* decision) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const int64_t* extent = GetLoopIntExtent(loop); std::vector result; if (extent == nullptr) { diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 07481ddb19e3..15d0e08ddc2c 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -208,7 +208,7 @@ class BlockInfoCollector : private StmtVisitor { if (is_root_block) { // If the block doesn't have outer loops and BlockRealize, // then we set the affine binding flag as true only if the block has no block vars - const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root); + const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); if (block->iter_vars.empty()) info.affine_binding = true; } else { info.affine_binding = @@ -233,7 +233,7 @@ class BlockInfoCollector : private StmtVisitor { block_reads_unbound.reserve(child_block_srefs.size()); block_writes_unbound.reserve(child_block_srefs.size()); for (const StmtSRef& block_sref : child_block_srefs) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Map binding = GetBindings(block2realize_.at(block)); // Step 1.1. Unbind read regions Array reads; @@ -254,7 +254,7 @@ class BlockInfoCollector : private StmtVisitor { for (const auto& kv : info.scope->dst2deps) { const StmtSRef& consumer_block_sref = kv.first; const Array& deps = kv.second; - const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref); + const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); const BlockRealize& consumer_realize = block2realize_.at(consumer_block); bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true; // Step 2.1. Extract the path to the scope root @@ -851,7 +851,7 @@ class ChildReplacer : private StmtMutator { } else if (const auto* realize = stmt.as()) { // Case 2. stmt is BlockRealize, src_stmt is Block if (realize->block.get() == src_stmt) { - const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode); + const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, BlockNode); ObjectPtr new_realize = make_object(*realize); new_realize->block = GetRef(tgt_block); new_stmt = BlockRealize(std::move(new_realize)); @@ -1044,9 +1044,9 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ // If `g_func` was unique, after the 3 lines above: // `ref_new_func` points to the same unique function that `g_func` points to // Update the body of the function the sref belongs to Assign - const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode); + const auto* realize = TVM_TYPE_AS(g_func->body, BlockRealizeNode); // Make `child_tgt_stmt` the root block - const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode); + const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, BlockNode); ObjectPtr new_realize = make_object(*realize); new_realize->block = GetRef(child_block); new_func->body = BlockRealize(std::move(new_realize)); @@ -1078,7 +1078,7 @@ void ScheduleStateNode::DebugVerify() const { /**************** BlockInfo-related ****************/ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + TVM_SREF_TO_BLOCK(block_sref); auto it = this->block_info.find(block_sref); CHECK(it != this->block_info.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 1c21d770db30..1ebaf202d487 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -36,7 +36,7 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec Buffer WithScope(const Buffer& buffer, const String& scope) { ObjectPtr new_buffer = make_object(*buffer.get()); ObjectPtr new_var = make_object(*buffer->data.get()); - const auto* ptr_type = TVM_TYPE_AS(ptr_type, buffer->data->type_annotation, PointerTypeNode); + const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode); new_var->type_annotation = PointerType(ptr_type->element_type, scope); new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation); new_buffer->name = buffer->name + "_" + scope; @@ -253,8 +253,8 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } } ICHECK(sref != nullptr && sref->stmt != nullptr); - const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block, leaf_block_sref); - const auto* scope_block = TVM_SREF_TO_BLOCK(scope_block, sref); + const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block_sref); + const auto* scope_block = TVM_SREF_TO_BLOCK(sref); throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 3db80989ae10..c289309acc2d 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -62,25 +62,35 @@ namespace tir { /*! * \brief A helper macro to convert an sref to the block it points to, - * throwing an internal error if downcasting fails - * \param Result The result variable, used for checking + * + * Throws an internal error if downcasting fails. The variable name + * in the parent scope is used for the error message. + * * \param SRef The SRef to be cast */ -#define TVM_SREF_TO_BLOCK(Result, SRef) \ - TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \ - << "TypeError: Expects StmtSRef `" << #SRef \ - << "` points to `Block`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None") +#define TVM_SREF_TO_BLOCK(SRef) \ + [&]() { \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode) \ + << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \ + << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ + return result; \ + }() /*! - * \brief A helper macro to convert an sref to the for-loop it points to, - * throwing an internal error if downcasting fails - * \param Result The name of the result variable, used for checking + * \brief A helper macro to convert an sref to the for-loop it points to + * + * Throws an internal error if downcasting fails. The variable name + * in the parent scope is used for the error message. + * * \param SRef The SRef to be cast */ -#define TVM_SREF_TO_FOR(Result, SRef) \ - TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \ - << "TypeError: Expects StmtSRef `" << #SRef \ - << "` points to `Loop`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None") +#define TVM_SREF_TO_FOR(SRef) \ + [&]() { \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode) \ + << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \ + << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ + return result; \ + }() /*! * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, @@ -100,10 +110,13 @@ namespace tir { * \param From The ObjectRef to be downcast * \param Type The type to be downcast to */ -#define TVM_TYPE_AS(Result, From, Type) \ - TVM_TYPE_AS_OR_ERR(Result, From, Type) \ - << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ - << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None") +#define TVM_TYPE_AS(From, Type) \ + [&]() { \ + auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type) \ + << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ + << "`, but gets: " << ((From).defined() ? (From)->GetTypeKey() : "None"); \ + return result; \ + }() /*! * \brief Convert an array of loop StmtSRefs to an array of loops @@ -114,7 +127,7 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { Array loops; loops.reserve(loop_srefs.size()); for (StmtSRef loop_sref : loop_srefs) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); loops.push_back(GetRef(loop)); } return loops; @@ -264,7 +277,7 @@ inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_in * \return The extent of the loop, nullptr if the extent is not constant */ inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); return as_const_int(loop->extent); }