From 276324673ae54dff686f88c1f394f9817e455c34 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 3 Mar 2023 02:08:31 -0800 Subject: [PATCH 01/11] refactor --- include/tvm/tir/analysis.h | 53 ++++++ src/tir/analysis/var_use_def_analysis.cc | 215 +++++++++++++++++++++++ src/tir/transforms/split_host_device.cc | 201 --------------------- 3 files changed, 268 insertions(+), 201 deletions(-) create mode 100644 src/tir/analysis/var_use_def_analysis.cc diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index a8edc2675fc4..45f5f6071849 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -30,6 +30,7 @@ #include #include #include +#include #include @@ -86,6 +87,58 @@ TVM_DLL double EstimateTIRFlops(const Stmt& stmt); */ TVM_DLL double EstimateTIRFlops(const IRModule& mod); +/*! + * \brief Visitor class to perform use/def analysis, also delete unreferenced lets. + * \sa UndefinedVars + */ +class VarUseDefAnalysis : public StmtExprMutator { + public: + // The fields are publically readible to + // be accessible to the users. + bool visit_thread_extent_{true}; + bool simplify_let_{true}; + Array undefined_; + Array thread_axis_; + Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; + std::unordered_map use_count_; + std::unordered_map def_count_; + + private: + ExprDeepEqual deep_equal_; + std::unordered_map let_binding_; + Stmt VisitStmt_(const AttrStmtNode* op) final; + + Stmt VisitStmt_(const LetStmtNode* op) final; + + Stmt VisitStmt_(const ForNode* op) final; + + Stmt VisitStmt_(const AllocateNode* op) final; + + Stmt VisitStmt_(const AllocateConstNode* op) final; + + Stmt VisitStmt_(const StoreNode* op) final; + + Stmt VisitStmt_(const BufferStoreNode* op) final; + + PrimExpr VisitExpr_(const LetNode* op) final; + + PrimExpr VisitExpr_(const VarNode* op) final; + + PrimExpr VisitExpr_(const ReduceNode* op) final; + + PrimExpr VisitExpr_(const LoadNode* op) final; + + PrimExpr VisitExpr_(const BufferLoadNode* op) final; + + void HandleDef(const VarNode* v); + + void HandleUse(const PrimExpr& v); + + void VisitBuffer(Buffer buffer); +}; + /*! * \brief Find undefined vars in the statement. * \param stmt The function to be checked. diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc new file mode 100644 index 000000000000..7370a1bff16e --- /dev/null +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file var_use_def_analysis.cc + * \brief Classes and functions to analyze var defition and usage. + */ +#include + +#include "../../runtime/thread_storage_scope.h" +#include "../transforms/ir_utils.h" + +namespace tvm { +namespace tir { + +Stmt VarUseDefAnalysis::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + // thread_extent can appear multiple times + // use the first appearance as def. + if (!use_count_.count(iv->var.get())) { + this->HandleDef(iv->var.get()); + thread_axis_.push_back(iv); + thread_extent_.push_back(op->value); + } + + PrimExpr value = op->value; + if (visit_thread_extent_) { + value = this->VisitExpr(value); + } + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return GetRef(op); + } + return AttrStmt(op->node, op->attr_key, value, body); + } else { + return StmtExprMutator::VisitStmt_(op); + } +} + +Stmt VarUseDefAnalysis::VisitStmt_(const LetStmtNode* op) { + this->HandleDef(op->var.get()); + Stmt body = this->VisitStmt(op->body); + // eliminate unreferenced let + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && + simplify_let_) { + return body; + } else { + PrimExpr value = this->VisitExpr(op->value); + if (body.same_as(op->body) && value.same_as(op->value)) { + return GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } +} + +Stmt VarUseDefAnalysis::VisitStmt_(const ForNode* op) { + this->HandleDef(op->loop_var.get()); + return StmtExprMutator::VisitStmt_(op); +} + +Stmt VarUseDefAnalysis::VisitStmt_(const AllocateNode* op) { + this->HandleDef(op->buffer_var.get()); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); + dyn_shmem_size_ = op->extents[0]; + for (size_t i = 1; i < op->extents.size(); ++i) { + dyn_shmem_size_ *= op->extents[i]; + } + dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); + use_dyn_shmem_ = true; + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt VarUseDefAnalysis::VisitStmt_(const AllocateConstNode* op) { + this->HandleDef(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); +} + +Stmt VarUseDefAnalysis::VisitStmt_(const StoreNode* op) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; +} + +Stmt VarUseDefAnalysis::VisitStmt_(const BufferStoreNode* op) { + VisitBuffer(op->buffer); + return StmtExprMutator::VisitStmt_(op); +} + +PrimExpr VarUseDefAnalysis::VisitExpr_(const LetNode* op) { + // Weaker SSA condition + // A single var can be binded in multiple lets + // but they have to bind to the same value. + // This is used to allow cases when we reuse a single let + // expression to construct a nested expr. + // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = let_binding_.find(op->var); + PrimExpr value = this->VisitExpr(op->value); + if (it != let_binding_.end()) { + ICHECK(deep_equal_(it->second->value, value)) + << "Let cannot bind the same var to two different values"; + return GetRef(it->second); + } else { + this->HandleDef(op->var.get()); + let_binding_[op->var] = op; + } + PrimExpr body = this->VisitExpr(op->body); + // eliminate unreferenced let + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && + simplify_let_) { + return body; + } else { + if (body.same_as(op->body) && value.same_as(op->value)) { + return GetRef(op); + } else { + return Let(op->var, value, body); + } + } +} + +PrimExpr VarUseDefAnalysis::VisitExpr_(const VarNode* op) { + this->HandleUse(GetRef(op)); + return StmtExprMutator::VisitExpr_(op); +} + +PrimExpr VarUseDefAnalysis::VisitExpr_(const ReduceNode* op) { + for (const auto& iv : op->axis) { + this->HandleDef(iv->var.get()); + } + return StmtExprMutator::VisitExpr_(op); +} + +PrimExpr VarUseDefAnalysis::VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; +} + +PrimExpr VarUseDefAnalysis::VisitExpr_(const BufferLoadNode* op) { + VisitBuffer(op->buffer); + return StmtExprMutator::VisitExpr_(op); +} + +void VarUseDefAnalysis::VisitBuffer(Buffer buffer) { + this->HandleUse(buffer->data); + auto visit_arr = [&](Array arr) { + for (const auto& element : arr) { + this->VisitExpr(element); + } + }; + + visit_arr(buffer->shape); + visit_arr(buffer->strides); +} + +void VarUseDefAnalysis::HandleDef(const VarNode* v) { + ICHECK(!def_count_.count(v)) << "variable " << v->name_hint + << " has already been defined, the Stmt is not SSA"; + ICHECK(!use_count_.count(v)) << "variable " << v->name_hint + << " has been used before definition!"; + use_count_[v] = 0; + def_count_[v] = 1; +} + +void VarUseDefAnalysis::HandleUse(const PrimExpr& v) { + ICHECK(v.as()); + Var var = Downcast(v); + auto it = use_count_.find(var.get()); + if (it != use_count_.end()) { + if (it->second >= 0) { + ++it->second; + } + } else { + undefined_.push_back(var); + use_count_[var.get()] = -1; + } +} + +Array UndefinedVars(const Stmt& stmt, const Array& args) { + VarUseDefAnalysis m; + m.simplify_let_ = false; + for (Var arg : args) { + m.use_count_[arg.get()] = 0; + } + m(stmt); + return m.undefined_; +} + +Array UndefinedVars(const PrimExpr& expr) { + VarUseDefAnalysis m; + m.simplify_let_ = false; + m(expr); + return m.undefined_; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 2de7d38d7d57..91a352b58917 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -40,207 +40,6 @@ namespace tvm { namespace tir { -// use/def analysis, also delete unreferenced lets -class VarUseDefAnalysis : public StmtExprMutator { - public: - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent) { - IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); - // thread_extent can appear multiple times - // use the first appearance as def. - if (!use_count_.count(iv->var.get())) { - this->HandleDef(iv->var.get()); - thread_axis_.push_back(iv); - thread_extent_.push_back(op->value); - } - - PrimExpr value = op->value; - if (visit_thread_extent_) { - value = this->VisitExpr(value); - } - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); - } - return AttrStmt(op->node, op->attr_key, value, body); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - Stmt VisitStmt_(const LetStmtNode* op) final { - this->HandleDef(op->var.get()); - Stmt body = this->VisitStmt(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && - simplify_let_) { - return body; - } else { - PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return LetStmt(op->var, value, body); - } - } - } - - Stmt VisitStmt_(const ForNode* op) final { - this->HandleDef(op->loop_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const AllocateNode* op) final { - this->HandleDef(op->buffer_var.get()); - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { - ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; - ICHECK_GT(op->extents.size(), 0); - dyn_shmem_size_ = op->extents[0]; - for (size_t i = 1; i < op->extents.size(); ++i) { - dyn_shmem_size_ *= op->extents[i]; - } - dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); - use_dyn_shmem_ = true; - } - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const AllocateConstNode* op) final { - this->HandleDef(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const StoreNode* op) final { - LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - VisitBuffer(op->buffer); - return StmtExprMutator::VisitStmt_(op); - } - - PrimExpr VisitExpr_(const LetNode* op) final { - // Weaker SSA condition - // A single var can be binded in multiple lets - // but they have to bind to the same value. - // This is used to allow cases when we reuse a single let - // expression to construct a nested expr. - // (let x = 1 in x + 1) * (let x = 1 in x + 1) - auto it = let_binding_.find(op->var); - PrimExpr value = this->VisitExpr(op->value); - if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, value)) - << "Let cannot bind the same var to two different values"; - return GetRef(it->second); - } else { - this->HandleDef(op->var.get()); - let_binding_[op->var] = op; - } - PrimExpr body = this->VisitExpr(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && - simplify_let_) { - return body; - } else { - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return Let(op->var, value, body); - } - } - } - - PrimExpr VisitExpr_(const VarNode* op) final { - this->HandleUse(GetRef(op)); - return StmtExprMutator::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const ReduceNode* op) final { - for (const auto& iv : op->axis) { - this->HandleDef(iv->var.get()); - } - return StmtExprMutator::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const LoadNode* op) final { - LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - VisitBuffer(op->buffer); - return StmtExprMutator::VisitExpr_(op); - } - - void VisitBuffer(Buffer buffer) { - this->HandleUse(buffer->data); - auto visit_arr = [&](Array arr) { - for (const auto& element : arr) { - this->VisitExpr(element); - } - }; - - visit_arr(buffer->shape); - visit_arr(buffer->strides); - } - - void HandleDef(const VarNode* v) { - ICHECK(!def_count_.count(v)) << "variable " << v->name_hint - << " has already been defined, the Stmt is not SSA"; - ICHECK(!use_count_.count(v)) << "variable " << v->name_hint - << " has been used before definition!"; - use_count_[v] = 0; - def_count_[v] = 1; - } - - void HandleUse(const PrimExpr& v) { - ICHECK(v.as()); - Var var = Downcast(v); - auto it = use_count_.find(var.get()); - if (it != use_count_.end()) { - if (it->second >= 0) { - ++it->second; - } - } else { - undefined_.push_back(var); - use_count_[var.get()] = -1; - } - } - - // The fields are publically readible to - // be accessible to the users. - bool visit_thread_extent_{true}; - bool simplify_let_{true}; - Array undefined_; - Array thread_axis_; - Array thread_extent_; - PrimExpr dyn_shmem_size_{0}; - bool use_dyn_shmem_{false}; - std::unordered_map use_count_; - std::unordered_map def_count_; - - private: - ExprDeepEqual deep_equal_; - std::unordered_map let_binding_; -}; - -Array UndefinedVars(const Stmt& stmt, const Array& args) { - VarUseDefAnalysis m; - m.simplify_let_ = false; - for (Var arg : args) { - m.use_count_[arg.get()] = 0; - } - m(stmt); - return m.undefined_; -} - -Array UndefinedVars(const PrimExpr& expr) { - VarUseDefAnalysis m; - m.simplify_let_ = false; - m(expr); - return m.undefined_; -} class HostDeviceSplitter : public StmtMutator { public: From d87c1be30acd30e6e2d3f668c1f8b03f051c66a9 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 3 Mar 2023 02:10:09 -0800 Subject: [PATCH 02/11] rename --- include/tvm/tir/analysis.h | 2 +- src/tir/analysis/var_use_def_analysis.cc | 34 ++++++++++++------------ src/tir/transforms/split_host_device.cc | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 45f5f6071849..b18fb0fa2193 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -91,7 +91,7 @@ TVM_DLL double EstimateTIRFlops(const IRModule& mod); * \brief Visitor class to perform use/def analysis, also delete unreferenced lets. * \sa UndefinedVars */ -class VarUseDefAnalysis : public StmtExprMutator { +class VarUseDefAnalyzer : public StmtExprMutator { public: // The fields are publically readible to // be accessible to the users. diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index 7370a1bff16e..c193a92a57f4 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -29,7 +29,7 @@ namespace tvm { namespace tir { -Stmt VarUseDefAnalysis::VisitStmt_(const AttrStmtNode* op) { +Stmt VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); @@ -55,7 +55,7 @@ Stmt VarUseDefAnalysis::VisitStmt_(const AttrStmtNode* op) { } } -Stmt VarUseDefAnalysis::VisitStmt_(const LetStmtNode* op) { +Stmt VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) { this->HandleDef(op->var.get()); Stmt body = this->VisitStmt(op->body); // eliminate unreferenced let @@ -72,12 +72,12 @@ Stmt VarUseDefAnalysis::VisitStmt_(const LetStmtNode* op) { } } -Stmt VarUseDefAnalysis::VisitStmt_(const ForNode* op) { +Stmt VarUseDefAnalyzer::VisitStmt_(const ForNode* op) { this->HandleDef(op->loop_var.get()); return StmtExprMutator::VisitStmt_(op); } -Stmt VarUseDefAnalysis::VisitStmt_(const AllocateNode* op) { +Stmt VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) { this->HandleDef(op->buffer_var.get()); auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { @@ -93,21 +93,21 @@ Stmt VarUseDefAnalysis::VisitStmt_(const AllocateNode* op) { return StmtExprMutator::VisitStmt_(op); } -Stmt VarUseDefAnalysis::VisitStmt_(const AllocateConstNode* op) { +Stmt VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) { this->HandleDef(op->buffer_var.get()); return StmtExprMutator::VisitStmt_(op); } -Stmt VarUseDefAnalysis::VisitStmt_(const StoreNode* op) { +Stmt VarUseDefAnalyzer::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } -Stmt VarUseDefAnalysis::VisitStmt_(const BufferStoreNode* op) { +Stmt VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) { VisitBuffer(op->buffer); return StmtExprMutator::VisitStmt_(op); } -PrimExpr VarUseDefAnalysis::VisitExpr_(const LetNode* op) { +PrimExpr VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { // Weaker SSA condition // A single var can be binded in multiple lets // but they have to bind to the same value. @@ -138,28 +138,28 @@ PrimExpr VarUseDefAnalysis::VisitExpr_(const LetNode* op) { } } -PrimExpr VarUseDefAnalysis::VisitExpr_(const VarNode* op) { +PrimExpr VarUseDefAnalyzer::VisitExpr_(const VarNode* op) { this->HandleUse(GetRef(op)); return StmtExprMutator::VisitExpr_(op); } -PrimExpr VarUseDefAnalysis::VisitExpr_(const ReduceNode* op) { +PrimExpr VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) { for (const auto& iv : op->axis) { this->HandleDef(iv->var.get()); } return StmtExprMutator::VisitExpr_(op); } -PrimExpr VarUseDefAnalysis::VisitExpr_(const LoadNode* op) { +PrimExpr VarUseDefAnalyzer::VisitExpr_(const LoadNode* op) { LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } -PrimExpr VarUseDefAnalysis::VisitExpr_(const BufferLoadNode* op) { +PrimExpr VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) { VisitBuffer(op->buffer); return StmtExprMutator::VisitExpr_(op); } -void VarUseDefAnalysis::VisitBuffer(Buffer buffer) { +void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) { this->HandleUse(buffer->data); auto visit_arr = [&](Array arr) { for (const auto& element : arr) { @@ -171,7 +171,7 @@ void VarUseDefAnalysis::VisitBuffer(Buffer buffer) { visit_arr(buffer->strides); } -void VarUseDefAnalysis::HandleDef(const VarNode* v) { +void VarUseDefAnalyzer::HandleDef(const VarNode* v) { ICHECK(!def_count_.count(v)) << "variable " << v->name_hint << " has already been defined, the Stmt is not SSA"; ICHECK(!use_count_.count(v)) << "variable " << v->name_hint @@ -180,7 +180,7 @@ void VarUseDefAnalysis::HandleDef(const VarNode* v) { def_count_[v] = 1; } -void VarUseDefAnalysis::HandleUse(const PrimExpr& v) { +void VarUseDefAnalyzer::HandleUse(const PrimExpr& v) { ICHECK(v.as()); Var var = Downcast(v); auto it = use_count_.find(var.get()); @@ -195,7 +195,7 @@ void VarUseDefAnalysis::HandleUse(const PrimExpr& v) { } Array UndefinedVars(const Stmt& stmt, const Array& args) { - VarUseDefAnalysis m; + VarUseDefAnalyzer m; m.simplify_let_ = false; for (Var arg : args) { m.use_count_[arg.get()] = 0; @@ -205,7 +205,7 @@ Array UndefinedVars(const Stmt& stmt, const Array& args) { } Array UndefinedVars(const PrimExpr& expr) { - VarUseDefAnalysis m; + VarUseDefAnalyzer m; m.simplify_let_ = false; m(expr); return m.undefined_; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 91a352b58917..962d3825beb2 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -65,7 +65,7 @@ class HostDeviceSplitter : public StmtMutator { os << name_prefix_ << "_kernel" << device_func_counter_++; std::string kernel_symbol = os.str(); // isolate the device function. - VarUseDefAnalysis m; + VarUseDefAnalyzer m; m.visit_thread_extent_ = false; body = m(std::move(body)); From 55c8a4c832393230d9e762fce1ce9b29918c2c6d Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 3 Mar 2023 02:20:58 -0800 Subject: [PATCH 03/11] remove redundancy --- src/tir/transforms/split_host_device.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 962d3825beb2..fc3c088dc619 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -34,8 +34,6 @@ #include -#include "../../runtime/thread_storage_scope.h" -#include "ir_utils.h" namespace tvm { namespace tir { From a3a60b40d99718a108d47cde3a71f7a27ad37eb1 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 4 Mar 2023 09:45:30 -0800 Subject: [PATCH 04/11] refactor --- include/tvm/tir/analysis.h | 86 ++++++---------------- src/tir/analysis/var_use_def_analysis.cc | 42 +++++------ src/tir/analysis/var_use_def_analysis.h | 92 ++++++++++++++++++++++++ src/tir/transforms/split_host_device.cc | 5 +- 4 files changed, 137 insertions(+), 88 deletions(-) create mode 100644 src/tir/analysis/var_use_def_analysis.h diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index b18fb0fa2193..7c7e43d39336 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -1,20 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file + * licensed to the apache software foundation (asf) under one + * or more contributor license agreements. see the notice file * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * regarding copyright ownership. the asf licenses this file + * to you under the apache license, version 2.0 (the + * "license"); you may not use this file except in compliance + * with the license. you may obtain a copy of the license at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/license-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the + * unless required by applicable law or agreed to in writing, + * software distributed under the license is distributed on an + * "as is" basis, without warranties or conditions of any + * kind, either express or implied. see the license for the * specific language governing permissions and limitations - * under the License. + * under the license. */ /*! @@ -87,61 +87,9 @@ TVM_DLL double EstimateTIRFlops(const Stmt& stmt); */ TVM_DLL double EstimateTIRFlops(const IRModule& mod); -/*! - * \brief Visitor class to perform use/def analysis, also delete unreferenced lets. - * \sa UndefinedVars - */ -class VarUseDefAnalyzer : public StmtExprMutator { - public: - // The fields are publically readible to - // be accessible to the users. - bool visit_thread_extent_{true}; - bool simplify_let_{true}; - Array undefined_; - Array thread_axis_; - Array thread_extent_; - PrimExpr dyn_shmem_size_{0}; - bool use_dyn_shmem_{false}; - std::unordered_map use_count_; - std::unordered_map def_count_; - - private: - ExprDeepEqual deep_equal_; - std::unordered_map let_binding_; - Stmt VisitStmt_(const AttrStmtNode* op) final; - - Stmt VisitStmt_(const LetStmtNode* op) final; - - Stmt VisitStmt_(const ForNode* op) final; - - Stmt VisitStmt_(const AllocateNode* op) final; - - Stmt VisitStmt_(const AllocateConstNode* op) final; - - Stmt VisitStmt_(const StoreNode* op) final; - - Stmt VisitStmt_(const BufferStoreNode* op) final; - - PrimExpr VisitExpr_(const LetNode* op) final; - - PrimExpr VisitExpr_(const VarNode* op) final; - - PrimExpr VisitExpr_(const ReduceNode* op) final; - - PrimExpr VisitExpr_(const LoadNode* op) final; - - PrimExpr VisitExpr_(const BufferLoadNode* op) final; - - void HandleDef(const VarNode* v); - - void HandleUse(const PrimExpr& v); - - void VisitBuffer(Buffer buffer); -}; - /*! * \brief Find undefined vars in the statement. - * \param stmt The function to be checked. + * \param stmt The statement to be checked. * \param defs The vars that is defined. * \return Array of undefined vars. */ @@ -154,6 +102,14 @@ TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); */ TVM_DLL Array UndefinedVars(const PrimExpr& expr); +/*! + * \brief Find undefined vars in the expression. + * \param stmt The statement to be checked. + * \param defs The vars that is defined. + * \return Array of undefined vars. + */ +TVM_DLL Array UndefinedVars(const PrimExpr& expr, const Array& defs); + /*! * \brief Analyze the side effect * \param expr The expression to be checked. diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index c193a92a57f4..19829813aa90 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -21,14 +21,17 @@ * \file var_use_def_analysis.cc * \brief Classes and functions to analyze var defition and usage. */ -#include - -#include "../../runtime/thread_storage_scope.h" -#include "../transforms/ir_utils.h" - +#include "var_use_def_analysis.h" namespace tvm { namespace tir { +VarUseDefAnalyzer::VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent) + : visit_thread_extent_(visit_thread_extent) { + for (const Var v : defined_vars) { + use_count_[v.get()] = 0; + } +} + Stmt VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -139,7 +142,7 @@ PrimExpr VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { } PrimExpr VarUseDefAnalyzer::VisitExpr_(const VarNode* op) { - this->HandleUse(GetRef(op)); + this->HandleUse(op); return StmtExprMutator::VisitExpr_(op); } @@ -160,7 +163,7 @@ PrimExpr VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) { } void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) { - this->HandleUse(buffer->data); + this->HandleUse(buffer->data.get()); auto visit_arr = [&](Array arr) { for (const auto& element : arr) { this->VisitExpr(element); @@ -180,33 +183,32 @@ void VarUseDefAnalyzer::HandleDef(const VarNode* v) { def_count_[v] = 1; } -void VarUseDefAnalyzer::HandleUse(const PrimExpr& v) { - ICHECK(v.as()); - Var var = Downcast(v); - auto it = use_count_.find(var.get()); +void VarUseDefAnalyzer::HandleUse(const VarNode* v) { + auto it = use_count_.find(v); if (it != use_count_.end()) { if (it->second >= 0) { ++it->second; } } else { - undefined_.push_back(var); - use_count_[var.get()] = -1; + undefined_.push_back(GetRef(v)); + use_count_[v] = -1; } } Array UndefinedVars(const Stmt& stmt, const Array& args) { - VarUseDefAnalyzer m; - m.simplify_let_ = false; - for (Var arg : args) { - m.use_count_[arg.get()] = 0; - } + VarUseDefAnalyzer m(args); m(stmt); return m.undefined_; } Array UndefinedVars(const PrimExpr& expr) { - VarUseDefAnalyzer m; - m.simplify_let_ = false; + VarUseDefAnalyzer m({}); + m(expr); + return m.undefined_; +} + +Array UndefinedVars(const PrimExpr& expr, const Array& args) { + VarUseDefAnalyzer m(args); m(expr); return m.undefined_; } diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h new file mode 100644 index 000000000000..8c7bad35550e --- /dev/null +++ b/src/tir/analysis/var_use_def_analysis.h @@ -0,0 +1,92 @@ + +/* + * licensed to the apache software foundation (asf) under one + * or more contributor license agreements. see the notice file + * distributed with this work for additional information + * regarding copyright ownership. the asf licenses this file + * to you under the apache license, version 2.0 (the + * "license"); you may not use this file except in compliance + * with the license. you may obtain a copy of the license at + * + * http://www.apache.org/licenses/license-2.0 + * + * unless required by applicable law or agreed to in writing, + * software distributed under the license is distributed on an + * "as is" basis, without warranties or conditions of any + * kind, either express or implied. see the license for the + * specific language governing permissions and limitations + * under the license. + */ + +/*! + * \file tvm/src/tir/analysis/var_use_def_analyzer.h + * \brief Variable definition and usage analysis class. + */ +#ifndef TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALZER_H_ +#define TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALZER_H_ + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "../transforms/ir_utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Visitor class to perform use/def analysis, also delete unreferenced lets. + * \sa UndefinedVars + */ +class VarUseDefAnalyzer : public StmtExprMutator { + public: + explicit VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent = true); + // The fields are publically readible to + // be accessible to the users. + bool visit_thread_extent_{true}; + bool simplify_let_{true}; + Array undefined_; + Array thread_axis_; + Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; + std::unordered_map use_count_; + std::unordered_map def_count_; + + private: + ExprDeepEqual deep_equal_; + std::unordered_map let_binding_; + Stmt VisitStmt_(const AttrStmtNode* op) final; + + Stmt VisitStmt_(const LetStmtNode* op) final; + + Stmt VisitStmt_(const ForNode* op) final; + + Stmt VisitStmt_(const AllocateNode* op) final; + + Stmt VisitStmt_(const AllocateConstNode* op) final; + + Stmt VisitStmt_(const StoreNode* op) final; + + Stmt VisitStmt_(const BufferStoreNode* op) final; + + PrimExpr VisitExpr_(const LetNode* op) final; + + PrimExpr VisitExpr_(const VarNode* op) final; + + PrimExpr VisitExpr_(const ReduceNode* op) final; + + PrimExpr VisitExpr_(const LoadNode* op) final; + + PrimExpr VisitExpr_(const BufferLoadNode* op) final; + + void HandleDef(const VarNode* v); + + void HandleUse(const VarNode* v); + + void VisitBuffer(Buffer buffer); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALZER_H_ diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index fc3c088dc619..09dc6b9ede05 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -34,11 +34,11 @@ #include +#include "../analysis/var_use_def_analysis.h" namespace tvm { namespace tir { - class HostDeviceSplitter : public StmtMutator { public: explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) @@ -63,8 +63,7 @@ class HostDeviceSplitter : public StmtMutator { os << name_prefix_ << "_kernel" << device_func_counter_++; std::string kernel_symbol = os.str(); // isolate the device function. - VarUseDefAnalyzer m; - m.visit_thread_extent_ = false; + VarUseDefAnalyzer m({}, false); body = m(std::move(body)); Array params; From 57dcf165b8288c640906dca976494477d8c0ce07 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 4 Mar 2023 10:26:38 -0800 Subject: [PATCH 05/11] split to three classes --- include/tvm/tir/analysis.h | 2 +- src/tir/analysis/var_use_def_analysis.cc | 95 ++++++------------- src/tir/analysis/var_use_def_analysis.h | 34 ++++--- src/tir/transforms/split_host_device.cc | 112 +++++++++++++++++++++-- 4 files changed, 147 insertions(+), 96 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 7c7e43d39336..8867568f215b 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -104,7 +104,7 @@ TVM_DLL Array UndefinedVars(const PrimExpr& expr); /*! * \brief Find undefined vars in the expression. - * \param stmt The statement to be checked. + * \param expr The expression to be checked. * \param defs The vars that is defined. * \return Array of undefined vars. */ diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index 19829813aa90..a8bbbf04d0ed 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -32,7 +32,7 @@ VarUseDefAnalyzer::VarUseDefAnalyzer(const Array& defined_vars, bool visit_ } } -Stmt VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { +void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); @@ -40,77 +40,48 @@ Stmt VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { // use the first appearance as def. if (!use_count_.count(iv->var.get())) { this->HandleDef(iv->var.get()); - thread_axis_.push_back(iv); - thread_extent_.push_back(op->value); } - PrimExpr value = op->value; if (visit_thread_extent_) { - value = this->VisitExpr(value); + this->VisitExpr(op->value); } - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); - } - return AttrStmt(op->node, op->attr_key, value, body); + + this->VisitStmt(op->body); } else { - return StmtExprMutator::VisitStmt_(op); + StmtExprVisitor::VisitStmt_(op); } } -Stmt VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) { +void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) { this->HandleDef(op->var.get()); - Stmt body = this->VisitStmt(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && - simplify_let_) { - return body; - } else { - PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return LetStmt(op->var, value, body); - } - } + StmtExprVisitor::VisitStmt_(op); } -Stmt VarUseDefAnalyzer::VisitStmt_(const ForNode* op) { +void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) { this->HandleDef(op->loop_var.get()); - return StmtExprMutator::VisitStmt_(op); + StmtExprVisitor::VisitStmt_(op); } -Stmt VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) { +void VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) { this->HandleDef(op->buffer_var.get()); - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { - ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; - ICHECK_GT(op->extents.size(), 0); - dyn_shmem_size_ = op->extents[0]; - for (size_t i = 1; i < op->extents.size(); ++i) { - dyn_shmem_size_ *= op->extents[i]; - } - dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); - use_dyn_shmem_ = true; - } - return StmtExprMutator::VisitStmt_(op); + StmtExprVisitor::VisitStmt_(op); } -Stmt VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) { +void VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) { this->HandleDef(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); + StmtExprVisitor::VisitStmt_(op); } -Stmt VarUseDefAnalyzer::VisitStmt_(const StoreNode* op) { +void VarUseDefAnalyzer::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } -Stmt VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) { +void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) { VisitBuffer(op->buffer); - return StmtExprMutator::VisitStmt_(op); + StmtExprVisitor::VisitStmt_(op); } -PrimExpr VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { +void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { // Weaker SSA condition // A single var can be binded in multiple lets // but they have to bind to the same value. @@ -118,48 +89,36 @@ PrimExpr VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { // expression to construct a nested expr. // (let x = 1 in x + 1) * (let x = 1 in x + 1) auto it = let_binding_.find(op->var); - PrimExpr value = this->VisitExpr(op->value); + this->VisitExpr(op->value); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, value)) + ICHECK(deep_equal_(it->second->value, op->value)) << "Let cannot bind the same var to two different values"; - return GetRef(it->second); } else { this->HandleDef(op->var.get()); let_binding_[op->var] = op; } - PrimExpr body = this->VisitExpr(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && - simplify_let_) { - return body; - } else { - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return Let(op->var, value, body); - } - } + this->VisitExpr(op->body); } -PrimExpr VarUseDefAnalyzer::VisitExpr_(const VarNode* op) { +void VarUseDefAnalyzer::VisitExpr_(const VarNode* op) { this->HandleUse(op); - return StmtExprMutator::VisitExpr_(op); + StmtExprVisitor::VisitExpr_(op); } -PrimExpr VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) { +void VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) { for (const auto& iv : op->axis) { this->HandleDef(iv->var.get()); } - return StmtExprMutator::VisitExpr_(op); + StmtExprVisitor::VisitExpr_(op); } -PrimExpr VarUseDefAnalyzer::VisitExpr_(const LoadNode* op) { +void VarUseDefAnalyzer::VisitExpr_(const LoadNode* op) { LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } -PrimExpr VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) { +void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) { VisitBuffer(op->buffer); - return StmtExprMutator::VisitExpr_(op); + StmtExprVisitor::VisitExpr_(op); } void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) { diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 8c7bad35550e..e349c68fac21 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -35,49 +35,47 @@ namespace tir { /*! * \brief Visitor class to perform use/def analysis, also delete unreferenced lets. + * \param defined_vars Variables that have been defined. + * \param visit_thread_extent Whether enters thread extent expressions or not. * \sa UndefinedVars */ -class VarUseDefAnalyzer : public StmtExprMutator { +class VarUseDefAnalyzer : public StmtExprVisitor { public: explicit VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent = true); // The fields are publically readible to // be accessible to the users. bool visit_thread_extent_{true}; - bool simplify_let_{true}; Array undefined_; - Array thread_axis_; - Array thread_extent_; - PrimExpr dyn_shmem_size_{0}; - bool use_dyn_shmem_{false}; + std::unordered_map use_count_; std::unordered_map def_count_; private: ExprDeepEqual deep_equal_; std::unordered_map let_binding_; - Stmt VisitStmt_(const AttrStmtNode* op) final; + void VisitStmt_(const AttrStmtNode* op) final; - Stmt VisitStmt_(const LetStmtNode* op) final; + void VisitStmt_(const LetStmtNode* op) final; - Stmt VisitStmt_(const ForNode* op) final; + void VisitStmt_(const ForNode* op) final; - Stmt VisitStmt_(const AllocateNode* op) final; + void VisitStmt_(const AllocateNode* op) final; - Stmt VisitStmt_(const AllocateConstNode* op) final; + void VisitStmt_(const AllocateConstNode* op) final; - Stmt VisitStmt_(const StoreNode* op) final; + void VisitStmt_(const StoreNode* op) final; - Stmt VisitStmt_(const BufferStoreNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; - PrimExpr VisitExpr_(const LetNode* op) final; + void VisitExpr_(const LetNode* op) final; - PrimExpr VisitExpr_(const VarNode* op) final; + void VisitExpr_(const VarNode* op) final; - PrimExpr VisitExpr_(const ReduceNode* op) final; + void VisitExpr_(const ReduceNode* op) final; - PrimExpr VisitExpr_(const LoadNode* op) final; + void VisitExpr_(const LoadNode* op) final; - PrimExpr VisitExpr_(const BufferLoadNode* op) final; + void VisitExpr_(const BufferLoadNode* op) final; void HandleDef(const VarNode* v); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 09dc6b9ede05..5c111feffdda 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -39,6 +39,95 @@ namespace tvm { namespace tir { +/*! + * \brief Visit class to collect device-side program information. + */ +class DeviceInfoCollector : public StmtVisitor { + public: + Array thread_axis_; + Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; + + private: + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + // thread_extent can appear multiple times + // use the first appearance as def. + if (!defined_thread.count(iv.get())) { + defined_thread.insert(iv.get()); + thread_axis_.push_back(iv); + thread_extent_.push_back(op->value); + } + + this->VisitExpr(op->value); + this->VisitStmt(op->body); + } else { + StmtVisitor::VisitStmt_(op); + } + } + + void VisitStmt_(const AllocateNode* op) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); + dyn_shmem_size_ = op->extents[0]; + for (size_t i = 1; i < op->extents.size(); ++i) { + dyn_shmem_size_ *= op->extents[i]; + } + dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); + use_dyn_shmem_ = true; + } + StmtVisitor::VisitStmt_(op); + } + + std::unordered_set defined_thread; +}; + +/*! + * \brief Visitor class to remove unrefenced let stmt/expressions. + */ +class UnreferencedLetRemover : public StmtExprMutator { + public: + explicit UnreferencedLetRemover(const std::unordered_map& use_count) + : use_count_(use_count) {} + + private: + Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt body = this->VisitStmt(op->body); + // eliminate unreferenced let + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { + return body; + } else { + PrimExpr value = this->VisitExpr(op->value); + if (body.same_as(op->body) && value.same_as(op->value)) { + return GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } + } + + PrimExpr VisitExpr_(const LetNode* op) final { + PrimExpr body = this->VisitExpr(op->body); + PrimExpr value = this->VisitExpr(op->value); + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { + return body; + } else { + if (body.same_as(op->body) && value.same_as(op->value)) { + return GetRef(op); + } else { + return Let(op->var, value, body); + } + } + } + + const std::unordered_map& use_count_; +}; + class HostDeviceSplitter : public StmtMutator { public: explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) @@ -63,15 +152,19 @@ class HostDeviceSplitter : public StmtMutator { os << name_prefix_ << "_kernel" << device_func_counter_++; std::string kernel_symbol = os.str(); // isolate the device function. - VarUseDefAnalyzer m({}, false); - body = m(std::move(body)); + VarUseDefAnalyzer var_use_def({}, false); + var_use_def(body); + DeviceInfoCollector dev_info; + dev_info(body); + UnreferencedLetRemover let_remover(var_use_def.use_count_); + body = let_remover(std::move(body)); Array params; Array arguments; Map remap_vars; // Strictly order the arguments: Var pointers, positional arguments. - for (Var var : m.undefined_) { + for (Var var : var_use_def.undefined_) { if (var.dtype().is_handle()) { // Create a new version of v. auto it = handle_data_type_.find(var.get()); @@ -91,7 +184,7 @@ class HostDeviceSplitter : public StmtMutator { } } // positional arguments - for (Var var : m.undefined_) { + for (Var var : var_use_def.undefined_) { if (!var.dtype().is_handle()) { params.push_back(var); arguments.push_back(var); @@ -101,7 +194,8 @@ class HostDeviceSplitter : public StmtMutator { GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); PrimFunc device_func(params, Substitute(body, remap_vars)); - device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_); + device_func = + WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, dev_info.thread_axis_); device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, @@ -109,7 +203,7 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); - if (m.use_dyn_shmem_) { + if (dev_info.use_dyn_shmem_) { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); } @@ -121,11 +215,11 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr arg : arguments) { call_args.push_back(arg); } - for (PrimExpr ext : m.thread_extent_) { + for (PrimExpr ext : dev_info.thread_extent_) { call_args.push_back(ext); } - if (m.use_dyn_shmem_) { - call_args.push_back(m.dyn_shmem_size_); + if (dev_info.use_dyn_shmem_) { + call_args.push_back(dev_info.dyn_shmem_size_); } return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } From e00f5de0cf74d9a93f5f5fe8dc5e5ae8ed55b5a1 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 4 Mar 2023 18:13:13 -0800 Subject: [PATCH 06/11] fix asf format --- include/tvm/tir/analysis.h | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 8867568f215b..8c21db6d2810 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -1,20 +1,20 @@ /* - * licensed to the apache software foundation (asf) under one - * or more contributor license agreements. see the notice file + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information - * regarding copyright ownership. the asf licenses this file - * to you under the apache license, version 2.0 (the - * "license"); you may not use this file except in compliance - * with the license. you may obtain a copy of the license at + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/license-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * unless required by applicable law or agreed to in writing, - * software distributed under the license is distributed on an - * "as is" basis, without warranties or conditions of any - * kind, either express or implied. see the license for the + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the * specific language governing permissions and limitations - * under the license. + * under the License. */ /*! From e60594b1aad4083d96f6cc1577577a7b27b0f2d6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 4 Mar 2023 18:36:19 -0800 Subject: [PATCH 07/11] another asf header --- src/tir/analysis/var_use_def_analysis.h | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index e349c68fac21..5f7567eee7b2 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -1,21 +1,20 @@ - /* - * licensed to the apache software foundation (asf) under one - * or more contributor license agreements. see the notice file + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information - * regarding copyright ownership. the asf licenses this file - * to you under the apache license, version 2.0 (the - * "license"); you may not use this file except in compliance - * with the license. you may obtain a copy of the license at + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/license-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * unless required by applicable law or agreed to in writing, - * software distributed under the license is distributed on an - * "as is" basis, without warranties or conditions of any - * kind, either express or implied. see the license for the + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the * specific language governing permissions and limitations - * under the license. + * under the License. */ /*! From f1da43b5cb422f5fb7d1e2d47c6690168f41bfe6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 5 Mar 2023 06:37:39 -0800 Subject: [PATCH 08/11] lint issue --- src/tir/analysis/var_use_def_analysis.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 5f7567eee7b2..0f1513e3fee6 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -21,8 +21,8 @@ * \file tvm/src/tir/analysis/var_use_def_analyzer.h * \brief Variable definition and usage analysis class. */ -#ifndef TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALZER_H_ -#define TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALZER_H_ +#ifndef TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ +#define TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ #include @@ -86,4 +86,4 @@ class VarUseDefAnalyzer : public StmtExprVisitor { } // namespace tir } // namespace tvm -#endif // TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALZER_H_ +#endif // TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ From 38d9e7452915dd4cdec64e6dae21a062d6327bc2 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 5 Mar 2023 07:47:10 -0800 Subject: [PATCH 09/11] include what you use --- src/tir/analysis/var_use_def_analysis.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 0f1513e3fee6..aada4e28660c 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -26,6 +26,8 @@ #include +#include + #include "../../runtime/thread_storage_scope.h" #include "../transforms/ir_utils.h" From 59f47ce521f3a6e5439c580d35f117050e158d14 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 5 Mar 2023 08:10:29 -0800 Subject: [PATCH 10/11] typo --- src/tir/transforms/split_host_device.cc | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 5c111feffdda..2b3e81b272f8 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -40,7 +40,7 @@ namespace tvm { namespace tir { /*! - * \brief Visit class to collect device-side program information. + * \brief Visitor class to collect device-side program information. */ class DeviceInfoCollector : public StmtVisitor { public: @@ -69,7 +69,7 @@ class DeviceInfoCollector : public StmtVisitor { } } - void VisitStmt_(const AllocateNode* op) { + void VisitStmt_(const AllocateNode* op) final { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; @@ -84,11 +84,13 @@ class DeviceInfoCollector : public StmtVisitor { StmtVisitor::VisitStmt_(op); } + // recording what thread axis have been visited. std::unordered_set defined_thread; }; /*! - * \brief Visitor class to remove unrefenced let stmt/expressions. + * \brief Mutator class to remove unrefenced let stmt/expressions. + * \param use_count The pre-computed variable to use count map. */ class UnreferencedLetRemover : public StmtExprMutator { public: @@ -125,6 +127,7 @@ class UnreferencedLetRemover : public StmtExprMutator { } } + // pre-computed variable to use count map. const std::unordered_map& use_count_; }; @@ -152,11 +155,11 @@ class HostDeviceSplitter : public StmtMutator { os << name_prefix_ << "_kernel" << device_func_counter_++; std::string kernel_symbol = os.str(); // isolate the device function. - VarUseDefAnalyzer var_use_def({}, false); - var_use_def(body); + VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); + use_def(body); DeviceInfoCollector dev_info; dev_info(body); - UnreferencedLetRemover let_remover(var_use_def.use_count_); + UnreferencedLetRemover let_remover(use_def.use_count_); body = let_remover(std::move(body)); Array params; @@ -164,7 +167,7 @@ class HostDeviceSplitter : public StmtMutator { Map remap_vars; // Strictly order the arguments: Var pointers, positional arguments. - for (Var var : var_use_def.undefined_) { + for (Var var : use_def.undefined_) { if (var.dtype().is_handle()) { // Create a new version of v. auto it = handle_data_type_.find(var.get()); @@ -184,7 +187,7 @@ class HostDeviceSplitter : public StmtMutator { } } // positional arguments - for (Var var : var_use_def.undefined_) { + for (Var var : use_def.undefined_) { if (!var.dtype().is_handle()) { params.push_back(var); arguments.push_back(var); From 096b174582c2f8a91180bbadf79b56c2f19d2463 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 6 Mar 2023 08:01:18 -0800 Subject: [PATCH 11/11] reorganize include --- include/tvm/tir/analysis.h | 1 - src/tir/analysis/var_use_def_analysis.cc | 4 ++-- src/tir/analysis/var_use_def_analysis.h | 6 ++---- src/tir/transforms/split_host_device.cc | 2 ++ 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 8c21db6d2810..3709aeb6f3c7 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -30,7 +30,6 @@ #include #include #include -#include #include diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index a8bbbf04d0ed..7ef8e532a396 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -88,14 +88,14 @@ void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { // This is used to allow cases when we reuse a single let // expression to construct a nested expr. // (let x = 1 in x + 1) * (let x = 1 in x + 1) - auto it = let_binding_.find(op->var); + auto it = let_binding_.find(op->var.get()); this->VisitExpr(op->value); if (it != let_binding_.end()) { ICHECK(deep_equal_(it->second->value, op->value)) << "Let cannot bind the same var to two different values"; } else { this->HandleDef(op->var.get()); - let_binding_[op->var] = op; + let_binding_[op->var.get()] = op; } this->VisitExpr(op->body); } diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index aada4e28660c..ad275011d90c 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -25,12 +25,10 @@ #define TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ #include +#include #include -#include "../../runtime/thread_storage_scope.h" -#include "../transforms/ir_utils.h" - namespace tvm { namespace tir { @@ -53,7 +51,7 @@ class VarUseDefAnalyzer : public StmtExprVisitor { private: ExprDeepEqual deep_equal_; - std::unordered_map let_binding_; + std::unordered_map let_binding_; void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const LetStmtNode* op) final; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 2b3e81b272f8..4f411228d262 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -34,7 +34,9 @@ #include +#include "../../runtime/thread_storage_scope.h" #include "../analysis/var_use_def_analysis.h" +#include "ir_utils.h" namespace tvm { namespace tir {