diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index a1697d807db9..76826fdf7c5a 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -414,6 +414,13 @@ TVM_DLL Pass BF16StorageLegalize();
  */
 TVM_DLL Pass FP8StorageLegalize();
 
+/*!
+ * \brief Inline calls to private functions
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InlinePrivateFunctions();
+
 /*!
  * \brief Rewrite the pointer content type of arguments,
  *  as well as Alloc internal to the function to use
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index a46b2d10373f..42c9aecd18e7 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -230,6 +230,17 @@ def StorageRewrite():
     return _ffi_api.StorageRewrite()  # type: ignore
 
 
+def InlinePrivateFunctions():
+    """Inline calls to private functions
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InlinePrivateFunctions()  # type: ignore
+
+
 def PointerValueTypeRewrite():
     """
     Rewrite the pointer content type of arguments,
     as well as Alloc internal to the function to use
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 7ead6e6ae6fb..5964f0293299 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -29,6 +29,7 @@
 #include <functional>
 
+#include "../transforms/ir_utils.h"
 #include "functor_common.h"
 
 namespace tvm {
@@ -115,18 +116,18 @@ class PrimFuncSpecializer : public StmtExprMutator {
  private:
   Stmt VisitStmt_(const BlockNode* op) final {
     // Step.0. Define buffer mappings which is allocated inside the block
-    Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
-        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<Buffer> alloc_buffers =
+        op->alloc_buffers.Map([this](const auto& buf) { return MutateAllocBuffer(buf); });
 
     // Step.1. Recursively visit block body
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<BlockNode>();
     ICHECK(op != nullptr);
 
-    Array<BufferRegion> reads = op->reads.Map(
-        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
-    Array<BufferRegion> writes = op->writes.Map(
-        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> reads =
+        op->reads.Map([this](const auto& region) { return MutateBufferRegion(region); });
+    Array<BufferRegion> writes =
+        op->writes.Map([this](const auto& region) { return MutateBufferRegion(region); });
 
     if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
         writes.same_as(op->writes)) {
@@ -140,16 +141,54 @@ class PrimFuncSpecializer : public StmtExprMutator {
     }
   }
 
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    // Visit the buffer before delegating to StmtExprMutator, so the
+    // buffer's replacement will be defined before the point of use.
+    Var old_buffer_var = op->buffer->data;
+    Buffer new_buf = MutateAllocBuffer(op->buffer);
+
+    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+
+    if (!new_buf.same_as(node->buffer)) {
+      node.CopyOnWrite()->buffer = new_buf;
+    }
+
+    // If the buffer variable is being remapped to an expression, we
+    // still need a tir::Var to be used as the buffer variable.
+    // Therefore, generate a LetStmt that will provide a tir::Var for the buffer to use.
+    //
+    // This step is only required when a buffer definition is using a
+    // previously-defined buffer variable, which is therefore eligible
+    // for specialization.
+    // An allocation in the `BlockNode::alloc_buffers` defines both the
+    // buffer variable and the buffer, so this check is unnecessary
+    // there.  In addition, if the buffer var has been remapped to
+    // another variable, it has already been handled as part of the
+    // buffer mutation.
+    Var new_buffer_var = node->buffer->data;
+    Stmt stmt = std::move(node);
+
+    if (new_buffer_var.same_as(old_buffer_var)) {
+      auto remapped_data = VisitExpr(old_buffer_var);
+      if (!remapped_data.same_as(old_buffer_var)) {
+        stmt = LetStmt(old_buffer_var, remapped_data, stmt);
+      }
+    }
+
+    return stmt;
+  }
+
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<BufferStoreNode>();
     ICHECK(op != nullptr);
-    auto it = buffer_map_.find(op->buffer);
-    if (it == buffer_map_.end()) {
+
+    auto new_buf = GetNewBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
       return GetRef<Stmt>(op);
     } else {
       auto n = CopyOnWrite(op);
-      n->buffer = it->second;
+      n->buffer = new_buf;
       return Stmt(n);
     }
   }
@@ -158,12 +197,13 @@
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<BufferLoadNode>();
     ICHECK(op != nullptr);
-    auto it = buffer_map_.find(op->buffer);
-    if (it == buffer_map_.end()) {
+
+    auto new_buf = GetNewBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
       return GetRef<PrimExpr>(op);
     } else {
       auto n = make_object<BufferLoadNode>(*op);
-      n->buffer = it->second;
+      n->buffer = new_buf;
       return PrimExpr(n);
     }
   }
@@ -198,17 +238,23 @@
  private:
   Buffer MutateBuffer(const Buffer& buffer) {
+    // For the data variable, only Var-to-Var remapping can be handled
+    // in MutateBuffer.  See the DeclBuffer visitor for the handling
+    // of Var-to-PrimExpr remapping.
+    Var data = VisitExpr(buffer->data).as<Var>().value_or(buffer->data);
+
     Array<PrimExpr> shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); });
     Array<PrimExpr> strides =
         buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); });
 
     PrimExpr elem_offset = VisitExpr(buffer->elem_offset);
 
-    if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
-        buffer->strides.same_as(strides)) {
+    if (buffer->data.same_as(data) && buffer->elem_offset.same_as(elem_offset) &&
+        buffer->shape.same_as(shape) && buffer->strides.same_as(strides)) {
       return buffer;
     } else {
       auto n = make_object<BufferNode>(*buffer.get());
+      n->data = std::move(data);
       n->elem_offset = std::move(elem_offset);
       n->shape = std::move(shape);
       n->strides = std::move(strides);
@@ -227,14 +273,33 @@
   }
 
   Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    ICHECK(!buffer_map_.count(alloc_buf))
+        << "Multiple points of definition found for buffer " << alloc_buf;
+
     Buffer buf = MutateBuffer(alloc_buf);
-    if (buf.same_as(alloc_buf)) {
-      return alloc_buf;
-    } else {
-      ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
-      buffer_map_[alloc_buf] = buf;
-      return buf;
+    buffer_map_[alloc_buf] = buf;
+    return buf;
+  }
+
+  Buffer GetNewBuffer(const Buffer& old_buffer) {
+    if (auto it = buffer_map_.find(old_buffer); it != buffer_map_.end()) {
+      return it->second;
     }
+
+    auto mutated = MutateBuffer(old_buffer);
+    ICHECK(mutated.same_as(old_buffer))
+        << "Buffer " << old_buffer << " (shape = " << old_buffer->shape << ")"
+        << " was used without a declaration, "
+        << "and would be specialized into " << mutated << " (shape = " << mutated->shape << ").  "
" + << "While usage of an undeclared buffer is currently allowed in TIR, " + << "mutation must occur at the buffer's point of definition " + << "(see discussion on https://github.com/apache/tvm/pull/14565 for more details). " + << "Please add a definition for this buffer, " + << "either in the PrimFunc's buffer_map, " + << "in a tir::Block's alloc_buffer, " + << "or in a DeclBuffer statement."; + + return old_buffer; } BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { @@ -311,6 +376,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer << " vs. " << specific_buf->strides.size() << "."; // Updating var mapping using specific_expr + build_var_mapping(specific_buf->data, buf_to_specialize->data); for (size_t i = 0; i < specific_buf->shape.size(); ++i) { build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]); } diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc new file mode 100644 index 000000000000..cc33ba9f86c2 --- /dev/null +++ b/src/tir/transforms/inline_private_functions.cc @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inline_private_functions.cc + * \brief Inline private functions to their callsite + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace transform { + +namespace { + +template +using PSet = std::unordered_set; + +template +using PMap = std::unordered_map; + +PMap> CollectCallMap(const IRModule& mod) { + struct Visitor : StmtExprVisitor { + GlobalVar current; + PMap> caller_lookup; + + void VisitExpr_(const CallNode* op) { + if (auto gvar = op->op.as()) { + caller_lookup[gvar.value()].insert(current); + } + StmtExprVisitor::VisitExpr_(op); + } + } visitor; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto prim_func = base_func.as()) { + visitor.current = gvar; + visitor(prim_func->body); + } + } + + return visitor.caller_lookup; +} + +PSet CollectRecursiveFunctions(const IRModule& mod) { + // Collect all direct callers + auto call_map = CollectCallMap(mod); + + // Propagate to find all indirect callers + while (true) { + bool made_change = false; + for (const auto& [callee, callers] : call_map) { + for (const auto& caller : callers) { + if (auto it = call_map.find(caller); it != call_map.end()) { + PSet& indirect_callers = it->second; + + auto res = indirect_callers.insert(callee); + made_change = made_change || res.second; + } + } + } + if (!made_change) { + break; + } + } + + // Filter all GlobalVars that can be called by themselves, either + // directly or indirectly. 
+  PSet<GlobalVar> recursive_funcs;
+  for (const auto& [caller, callees] : call_map) {
+    if (callees.count(caller)) {
+      recursive_funcs.insert(caller);
+    }
+  }
+  return recursive_funcs;
+}
+
+bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func,
+                         const PSet<GlobalVar>& recursive_functions) {
+  // Only inline private functions.  Externally-exposed functions
+  // must be preserved to avoid breaking callsites outside of the
+  // IRModule.
+  bool is_exposed = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+  if (is_exposed) return false;
+
+  // We do not currently implement any analysis for termination of
+  // a function.  If a recursive function requires runtime checks
+  // in order to terminate, we would keep inlining until the
+  // recursive visits segfault.
+  bool is_recursive = recursive_functions.count(gvar);
+  if (is_recursive) return false;
+
+  // We do not currently support inlining of functions that accept
+  // buffer arguments.
+  bool has_buffer_arguments = prim_func->buffer_map.size();
+  if (has_buffer_arguments) return false;
+
+  // We do not currently support inlining of schedulable TIR
+  // functions.  To support this use case, repeated names in
+  // `tir::Block` nodes resulting from multiple calls to the same
+  // inlined function will need to be de-duplicated.
+  bool has_block_node = prim_func->body.as<BlockRealizeNode>();
+  if (has_block_node) return false;
+
+  return true;
+}
+
+Map<GlobalVar, PrimFunc> CollectInlinablePrimFuncs(const IRModule& mod) {
+  auto recursive_functions = CollectRecursiveFunctions(mod);
+
+  Map<GlobalVar, PrimFunc> output;
+  for (const auto& [gvar, base_func] : mod->functions) {
+    if (auto opt = base_func.as<PrimFunc>()) {
+      auto prim_func = opt.value();
+      if (IsInlinablePrimFunc(gvar, prim_func, recursive_functions)) {
+        output.Set(gvar, prim_func);
+      }
+    }
+  }
+
+  return output;
+}
+
+class PrimFuncInliner : StmtExprMutator {
+ public:
+  explicit PrimFuncInliner(Map<GlobalVar, PrimFunc> inlinable_funcs)
+      : inlinable_funcs_(inlinable_funcs) {
+    for (const auto& [gvar, callee] : inlinable_funcs_) {
+      removable_funcs_.insert(gvar);
+    }
+  }
+
+  PrimFunc VisitFunc(PrimFunc func) {
+    current_target_ = func->GetAttr<Target>(tvm::attr::kTarget);
+    auto new_body = VisitStmt(func->body);
+    current_target_ = NullOpt;
+
+    if (!new_body.same_as(func->body)) {
+      func.CopyOnWrite()->body = new_body;
+    }
+
+    return func;
+  }
+
+  PSet<GlobalVar> GetRemovableFunctions() const { return removable_funcs_; }
+
+ private:
+  Stmt VisitStmt_(const EvaluateNode* eval) override {
+    if (auto inlined = GetInlinedFunction(eval)) {
+      return inlined.value();
+    } else {
+      return StmtExprMutator::VisitStmt_(eval);
+    }
+  }
+
+  Optional<Stmt> GetInlinedFunction(const EvaluateNode* eval) {
+    auto call = eval->value.as<CallNode>();
+    if (!call) return NullOpt;
+
+    auto gvar = call->op.as<GlobalVar>();
+    if (!gvar) return NullOpt;
+
+    auto opt_callee = inlinable_funcs_.Get(gvar.value());
+    if (!opt_callee) return NullOpt;
+    auto callee = opt_callee.value();
+
+    bool is_same_target = [&]() -> bool {
+      auto callee_target = callee->GetAttr<Target>(tvm::attr::kTarget);
+      if (current_target_ && callee_target) {
+        return callee_target.value()->str() == current_target_.value()->str();
+      } else {
+        return true;
+      }
+    }();
+    if (!is_same_target) return NullOpt;
+
+    Stmt inlined = InlineArguments(gvar.value(), callee, call->args);
+    return VisitStmt(inlined);
+  }
+
+  PrimExpr VisitExpr_(const CallNode* call) override {
+    // Because the current implementation inlines a subroutine by
+    // inserting its `tir::Stmt` body at the point of use, replacement
+    // must occur in a context where a `tir::Stmt` can be returned.
+    // Support of subroutines that are called within an expression
+    // (e.g. replacing `func` in `Buf[0] = func(1) + func(2)`) would
+    // require hoisting preprocessing done in the subroutine to the
+    // parent `tir::Stmt`.
+    //
+    // See `TestInlineCallOccurringInExpression` in
+    // `test_tir_inline_private_functions.py` for a test of this
+    // behavior, currently marked with `pytest.mark.xfail`.
+    //
+    // Any callee that hasn't been inlined at this point must be kept
+    // in the output IRModule.
+    if (auto gvar = call->op.as<GlobalVar>()) {
+      removable_funcs_.erase(gvar.value());
+    }
+    return StmtExprMutator::VisitExpr_(call);
+  }
+
+  Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, const Array<PrimExpr>& args) const {
+    CHECK_EQ(callee->params.size(), args.size())
+        << "Callee " << gvar << " accepts " << callee->params.size() << " parameters ("
+        << callee->params << "), but is called with " << args.size() << " arguments (" << args
+        << ")";
+
+    ICHECK(callee->buffer_map.empty())
+        << "Inlining of PrimFuncs with buffer arguments is not yet supported, "
+        << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map;
+
+    Map<Var, PrimExpr> param_map;
+    for (size_t i = 0; i < callee->params.size(); i++) {
+      param_map.Set(callee->params[i], args[i]);
+    }
+
+    callee = Specialize(callee, param_map);
+
+    return callee->body;
+  }
+
+  // Map from GlobalVar to PrimFuncs which may be inlined.
+  Map<GlobalVar, PrimFunc> inlinable_funcs_;
+
+  /* \brief Set of callees that may be removed
+   *
+   * Some constructs may not be inlined (e.g. if the call site occurs
+   * outside of an Evaluate node).  For these cases, the output
+   * IRModule must still contain the callee.
+   */
+  PSet<GlobalVar> removable_funcs_;
+
+  Optional<Target> current_target_ = NullOpt;
+};
+
+}  // namespace
+
+Pass InlinePrivateFunctions() {
+  auto pass_func = [](IRModule mod, PassContext ctx) {
+    auto inlinable_prim_funcs = CollectInlinablePrimFuncs(mod);
+
+    if (inlinable_prim_funcs.empty()) {
+      // Early bail-out if the module has no inlinable PrimFuncs.
+      return mod;
+    }
+
+    PrimFuncInliner mutator(std::move(inlinable_prim_funcs));
+    IRModule updates;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto opt = base_func.as<PrimFunc>()) {
+        auto updated = mutator.VisitFunc(opt.value());
+        if (!updated.same_as(base_func)) {
+          updates->Add(gvar, updated);
+        }
+      }
+    }
+
+    if (updates->functions.size()) {
+      auto write_ptr = mod.CopyOnWrite();
+      write_ptr->Update(updates);
+      for (const auto& gvar : mutator.GetRemovableFunctions()) {
+        write_ptr->Remove(gvar);
+      }
+      mod = ConvertSSA()(mod);
+    }
+
+    return mod;
+  };
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions").set_body_typed(InlinePrivateFunctions);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py
index 508730aacfe2..f695b8522594 100644
--- a/tests/python/tir-base/test_tir_specialize.py
+++ b/tests/python/tir-base/test_tir_specialize.py
@@ -169,28 +169,6 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int3
             B[vi, vj] = A[vi, vj]
 
 
-@T.prim_func
-def param_in_arith_exprs(a: T.handle, b: T.handle) -> None:
-    n = T.int32()
-    A = T.match_buffer(a, [n // 8, 8], "int32")
-    B = T.match_buffer(b, [n], "int32")
-    for i in range(n - 1):
-        with T.block():
-            vi = T.axis.S(n - 1, i)
-            B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42
-
-
-@T.prim_func
-def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None:
-    n = T.int32()
-    A = T.match_buffer(a, [2, 8], "int32")
-    B = T.match_buffer(b, [16], "int32")
-    for i in range(15):
-        with T.block():
-            vi = T.axis.S(15, i)
-            B[vi] = A[vi // 8, vi % 8] + 714
-
-
 def test_specialize_nothing():
     func = matmul.specialize({})
     assert func.same_as(matmul)  # Pointer the same
@@ -238,15 +216,113 @@ def test_specialize_recursive_load():
 
 
 def test_specialize_with_const_folding():
-    b = param_in_arith_exprs.params[1]
-    func = param_in_arith_exprs.specialize({b: tvm.tir.decl_buffer([16])})
-    assert_structural_equal_ignore_global_symbol(func, param_in_arith_exprs_n_16)
+    @T.prim_func
+    def before(a: T.handle, b: T.handle):
+        n = T.int32()
+        A = T.match_buffer(a, [n // 8, 8], "int32")
+        B = T.match_buffer(b, [n], "int32")
+        for i in range(n - 1):
+            with T.block():
+                vi = T.axis.S(n - 1, i)
+                B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42
+
+    @T.prim_func
+    def expected(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, [2, 8], "int32")
+        B = T.match_buffer(b, [16], "int32")
+        for i in range(15):
+            with T.block():
+                vi = T.axis.S(15, i)
+                B[vi] = A[vi // 8, vi % 8] + 714
+
+    b = before.params[1]
+    after = before.specialize({b: tvm.tir.decl_buffer([16], dtype="int32")})
+    assert_structural_equal_ignore_global_symbol(expected, after)
+
+
+def test_specialize_decl_buffer():
+    """Buffers occurring in a DeclBuffer statement should be updated"""
+
+    @T.prim_func(private=True)
+    def before(A_data: T.handle("float32"), A_size: T.int32):
+        A_buf = T.decl_buffer(A_size, "float32", data=A_data)
+        for i in range(A_size):
+            A_buf[i] = A_buf[i] * 2.0
+
+    @T.prim_func(private=True)
+    def expected(A_data: T.handle("float32")):
+        A_buf = T.decl_buffer(16, "float32", data=A_data)
+        for i in range(16):
+            A_buf[i] = A_buf[i] * 2.0
+
+    param_map = {before.params[1]: T.int32(16)}
+    after = before.specialize(param_map)
+
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_specialize_buffer_var_to_var():
+    """A buffer var may be remapped by specialization
+
+    If a buffer variable is replaced by a specialization, then other
+    buffers using the same buffer var should also be updated.
+    """
+
+    @T.prim_func(private=True)
+    def before(A: T.Buffer([16, 16], "float32"), B: T.Buffer([16, 16], "float32")):
+        A_flat = T.decl_buffer([256], "float32", data=A.data)
+        B_flat = T.decl_buffer([256], "float32", data=B.data)
+        for i in range(256):
+            B_flat[i] = A_flat[i] * 2.0
+
+    @T.prim_func(private=True)
+    def expected(A: T.Buffer([16, 16], "float32"), B_handle: T.handle):
+        B = T.match_buffer(B_handle, [16, 16], "float32", data=A.data)
+        A_flat = T.decl_buffer([256], "float32", data=A.data)
+        B_flat = T.decl_buffer([256], "float32", data=A.data)
+        for i in range(256):
+            B_flat[i] = A_flat[i] * 2.0
+
+    A = before.buffer_map[before.params[0]]
+    B_handle = before.params[1]
+    param_map = {B_handle: A}
+    after = before.specialize(param_map)
+
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_specialize_buffer_var_to_expr():
+    """Handle specialization of buffer var
+
+    The `tir::Buffer::data` field must be an explicit `tir::Var`, and
+    cannot be replaced with a `tir::PrimExpr` of type
+    `DataType::Handle()`.  However, these substitutions are useful
+    when lowering.  If these occur, a binding of the `tir::Var` is
+    included in the specialized function.
+    """
+
+    @T.prim_func(private=True)
+    def before(A_data: T.handle("float32"), B_data: T.handle("float32")):
+        A_buf = T.decl_buffer(32, "float32", data=A_data)
+        B_buf = T.decl_buffer(16, "float32", data=B_data)
+        for i in range(16):
+            B_buf[i] = A_buf[i] * 2.0
+
+    @T.prim_func(private=True)
+    def expected(A_data: T.handle("float32")):
+        A_buf = T.decl_buffer(32, "float32", data=A_data)
+        B_data: T.Ptr[T.float32] = T.address_of(A_buf[16])
+        B_buf = T.decl_buffer(16, "float32", data=B_data)
+        for i in range(16):
+            B_buf[i] = A_buf[i] * 2.0
+
+    B_data = before.params[1]
+    A_buf = before.body.buffer
+    param_map = {B_data: tvm.tir.address_of(A_buf[16])}
+    after = before.specialize(param_map)
+
+    tvm.ir.assert_structural_equal(expected, after)
 
 
 if __name__ == "__main__":
-    test_specialize_nothing()
-    test_specialize_matmul()
-    test_specialize_elemwise()
-    test_specialize_mem_copy()
-    test_specialize_recursive_load()
-    test_specialize_with_const_folding()
+    tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py b/tests/python/tir-transform/test_tir_inline_private_functions.py
new file mode 100644
index 000000000000..2edf74ebfb3d
--- /dev/null
+++ b/tests/python/tir-transform/test_tir_inline_private_functions.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm.testing
+from tvm.script import ir as I, tir as T
+
+
+class BaseTestCase:
+    def test_well_formed(self):
+        After = tvm.tir.transform.InlinePrivateFunctions()(self.Before)
+        tvm.tir.analysis.verify_well_formed(After)
+
+    def test_produces_expected(self):
+        After = tvm.tir.transform.InlinePrivateFunctions()(self.Before)
+        tvm.ir.assert_structural_equal(self.Expected, After)
+
+
+class TestSimple(BaseTestCase):
+    """Simple case directly acting on PrimFunc"""
+
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")):
+            for i in range(64):
+                Before.subroutine(T.address_of(A[i, 0]), T.address_of(B[i, 0]))
+
+        @T.prim_func(private=True)
+        def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")):
+            A = T.decl_buffer([16, 16], "float32", data=A_data)
+            B = T.decl_buffer([16], "float32", data=B_data)
+            for i in range(16):
+                B[i] = 0.0
+                for j in range(16):
+                    B[i] = B[i] + A[i, j]
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")):
+            for i in range(64):
+                A_view_data: T.handle("float32") = T.address_of(A[i, 0])
+                Aview = T.decl_buffer([16, 16], "float32", data=A_view_data)
+                B_view_data: T.handle("float32") = T.address_of(B[i, 0])
+                Bview = T.decl_buffer([16], "float32", data=B_view_data)
+                for j in range(16):
+                    Bview[j] = 0.0
+                    for k in range(16):
+                        Bview[j] = Bview[j] + Aview[j, k]
+
+
+class TestRetainCrossFunctionSubroutines(BaseTestCase):
+    """Do not inline functions that cross device boundaries
+
+    When lowering TIR, calls for which the callsite and callee have
+    different targets are used at some stages, before being further
+    lowered to explicit device kernel launches.  Since inlining the
+    function would remove this cross-device information,
+    InlinePrivateFunctions should not inline these cases.
+    """
+
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")):
+            T.func_attr({"target": T.target("llvm")})
+            for i in range(64):
+                Before.subroutine(T.address_of(A[i, 0]), T.address_of(B[i, 0]))
+
+        @T.prim_func(private=True)
+        def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")):
+            T.func_attr({"target": T.target("cuda")})
+            A = T.decl_buffer([16, 16], "float32", data=A_data)
+            B = T.decl_buffer([16], "float32", data=B_data)
+            for i in range(16):
+                B[i] = 0.0
+                for j in range(16):
+                    B[i] = B[i] + A[i, j]
+
+    Expected = Before
+
+
+class TestRetainRecursiveSubroutines(BaseTestCase):
+    """Do not inline recursive functions
+
+    To avoid potentially infinite loops at compile-time, disable
+    inlining of recursive functions.  If inlining of these functions
+    would be useful, this restriction may be relaxed with improved
+    analysis of the subroutine.
+ """ + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + Before.subroutine(T.address_of(A[0]), 16) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), A_size: T.int32): + A = T.decl_buffer(A_size, "float32", data=A_data) + A[1] = A[0] + A[1] + + if A_size > 1: + Before.subroutine(T.address_of(A[1]), A_size - 1) + + Expected = Before + + +class TestDeduplicateBlockName(BaseTestCase): + """Block names must be de-duplicated after inlining""" + + @pytest.mark.xfail(reason="Inlining of schedulable TIR not yet supported") + def test_produces_expected(self): + super().test_produces_expected(self) + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer([2, 16], "float32"), B: T.Buffer([2, 16], "float32")): + Before.subroutine(T.address_of(A[0, 0]), T.address_of(B[0, 0])) + Before.subroutine(T.address_of(A[1, 0]), T.address_of(B[1, 0])) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + with T.block("scalar_mul"): + B[i] = A[i] * 2.0 + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + with T.LetStmt(T.address_of(A[0, 0]), var=T.handle("float32")) as A_data_1: + A_1 = T.decl_buffer(16, "float32", data=A_data_1) + B_data_1: T.handle("float32") = T.address_of(B[0, 0]) + B_1 = T.decl_buffer(16, "float32", data=B_data_1) + for i in range(16): + with T.block("scalar_mul_1"): + B_1[i] = A_1[i] * 2.0 + + with T.LetStmt(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2: + A_2 = T.decl_buffer(16, "float32", data=A_data_2) + B_data_2: T.handle("float32") = T.address_of(B[1, 0]) + B_2 = T.decl_buffer(16, "float32", data=B_data_2) + for i in range(16): + with T.block("scalar_mul_2"): + B_2[i] = A_2[i] * 2.0 + + +class TestInlineCallOccurringInExpression(BaseTestCase): + """Inline a Call node that is used in a function + + The current implementation only replaces `tir.Call` instances that + occur in a `tir.Evaluate` context. This is the primary use case, + used in destination-passing style. + + This unit test is marked as xfail. If/when the implementation + supports inlining of function calls occurring as part of an + expression, the annotation can be removed. + """ + + @pytest.mark.xfail(reason="Inlining of PrimFuncs outside of tir.Evaluate is not yet supported") + def test_produces_expected(self): + super().test_produces_expected(self) + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = Before.subroutine(i) + + @T.prim_func(private=True) + def subroutine(i: T.int32) -> T.float32: + cos = T.cos(T.cast(i, "float32")) + sin = T.sin(T.cast(i, "float32")) + retval = cos * cos + sin * sin + T.ret(retval) + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + for i in range(16): + cos = T.cos(T.cast(i, "float32")) + sin = T.sin(T.cast(i, "float32")) + retval = cos * cos + sin * sin + A[i] = retval + + +class TestInlineFunctionWithBufferArguments(BaseTestCase): + """Inline a function that accepts buffer arguments + + The current implementation does not support this usage. This unit + test is provided to display a possible user interaction, and is + marked with `@pytest.mark.xfail`. 
+    If/when the implementation supports inlining of function calls
+    with buffer arguments, the annotation can be removed.
+    """
+
+    @pytest.mark.xfail(reason="Inlining of PrimFuncs with buffer arguments is not yet supported")
+    def test_produces_expected(self):
+        super().test_produces_expected()
+
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer(16, "float32")):
+            Before.subroutine(
+                T.tvm_stack_make_array(
+                    A.data,
+                    T.tvm_stack_make_shape(*A.shape, dtype="handle"),
+                    0,
+                    len(A.shape),
+                    0.0,
+                    A.elem_offset,
+                    dtype="handle",
+                )
+            )
+
+        @T.prim_func(private=True)
+        def subroutine(A: T.Buffer(16, "float32")):
+            for i in range(16):
+                A[i] = A[i] * 2.0
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer(16, "float32")):
+            for i in range(16):
+                A[i] = A[i] * 2.0
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
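
For reference, the new pass can be exercised end-to-end as in the sketch below. This is a minimal usage example modeled on the `TestSimple` case above; the `Module`, `main`, and `subroutine` names are illustrative, not part of the patch:

```python
import tvm
from tvm.script import ir as I, tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer(16, "float32")):
        # The call site is a bare Evaluate statement, which is the
        # supported case for inlining.
        Module.subroutine(T.address_of(A[0]))

    # Private (no kGlobalSymbol), non-recursive, no buffer arguments,
    # and no tir::Block nodes, so this callee qualifies for inlining.
    @T.prim_func(private=True)
    def subroutine(A_data: T.handle("float32")):
        A = T.decl_buffer(16, "float32", data=A_data)
        for i in range(16):
            A[i] = A[i] * 2.0


# The subroutine body is specialized onto the call arguments and
# spliced into `main`; once every call site has been inlined, the
# private callee is removed and the module is converted back to SSA.
After = tvm.tir.transform.InlinePrivateFunctions()(Module)
tvm.tir.analysis.verify_well_formed(After)
print(After.script())
```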