From 3a7605e70b345fb70027259cfa0411906b5cbcea Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 14 Nov 2023 09:23:58 -0600 Subject: [PATCH 1/7] [TIR] Update DeclBuffer nodes when specializing PrimFunc Prior to this commit, a buffer whose parameters (e.g. shape/stride) contained a specialized parameter would not be updated when appearing in a `DeclBuffer` node. This commit updates the `Specialize` function to update buffers that occur in `DeclBuffer` nodes. --- src/tir/ir/specialize.cc | 61 ++++++++++++++++---- tests/python/tir-base/test_tir_specialize.py | 28 +++++++-- 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 7ead6e6ae6fb..9ccc5cecd027 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -140,16 +140,33 @@ class PrimFuncSpecializer : public StmtExprMutator { } } + Stmt VisitStmt_(const DeclBufferNode* op) final { + auto new_buf = MutateAllocBuffer(op->buffer); + + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + + if (new_buf.same_as(op->buffer)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->buffer = new_buf; + return Stmt(n); + } + } + Stmt VisitStmt_(const BufferStoreNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - auto it = buffer_map_.find(op->buffer); - if (it == buffer_map_.end()) { + + auto new_buf = GetNewBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { return GetRef(op); } else { auto n = CopyOnWrite(op); - n->buffer = it->second; + n->buffer = new_buf; return Stmt(n); } } @@ -158,12 +175,13 @@ class PrimFuncSpecializer : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); ICHECK(op != nullptr); - auto it = buffer_map_.find(op->buffer); - if (it == buffer_map_.end()) { + + auto new_buf = GetNewBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { return GetRef(op); } else { auto n = make_object(*op); - n->buffer = it->second; + n->buffer = new_buf; return PrimExpr(n); } } @@ -227,14 +245,33 @@ class PrimFuncSpecializer : public StmtExprMutator { } Buffer MutateAllocBuffer(const Buffer& alloc_buf) { + ICHECK(!buffer_map_.count(alloc_buf)) + << "Multiple points of definition found for buffer " << alloc_buf; + Buffer buf = MutateBuffer(alloc_buf); - if (buf.same_as(alloc_buf)) { - return alloc_buf; - } else { - ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end()); - buffer_map_[alloc_buf] = buf; - return buf; + buffer_map_[alloc_buf] = buf; + return buf; + } + + Buffer GetNewBuffer(const Buffer& old_buffer) { + if (auto it = buffer_map_.find(old_buffer); it != buffer_map_.end()) { + return it->second; } + + auto mutated = MutateBuffer(old_buffer); + ICHECK(mutated.same_as(old_buffer)) + << "Buffer " << old_buffer << " (shape = " << old_buffer->shape << ")" + << " was used without a declaration, " + << "and would be specialized into " << mutated << " (shape = " << mutated->shape << "). " + << "While usage of an undeclared buffer is currently allowed in TIR, " + << "mutation must occur at the buffer's point of definition " + << "(see discussion on https://github.com/apache/tvm/pull/14565 for more details). " + << "Please add a definition for this buffer, " + << "either in the PrimFunc's buffer_map, " + << "in a tir::Block's alloc_buffer, " + << "or in a DeclBuffer statement."; + + return old_buffer; } BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index 508730aacfe2..d98c76d140f7 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -243,10 +243,26 @@ def test_specialize_with_const_folding(): assert_structural_equal_ignore_global_symbol(func, param_in_arith_exprs_n_16) +def test_specialize_decl_buffer(): + """Buffers occurring in a DeclBuffer statement should be updated""" + + @T.prim_func(private=True) + def before(A_data: T.handle("float32"), A_size: T.int32): + A_buf = T.decl_buffer(A_size, "float32", data=A_data) + for i in range(A_size): + A_buf[i] = A_buf[i] * 2.0 + + @T.prim_func(private=True) + def expected(A_data: T.handle("float32")): + A_buf = T.decl_buffer(16, "float32", data=A_data) + for i in range(16): + A_buf[i] = A_buf[i] * 2.0 + + param_map = {before.params[1]: T.int32(16)} + after = before.specialize(param_map) + + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": - test_specialize_nothing() - test_specialize_matmul() - test_specialize_elemwise() - test_specialize_mem_copy() - test_specialize_recursive_load() - test_specialize_with_const_folding() + tvm.testing.main() From eb2a853bdba748f9878103a8af2d43375a3ab937 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 14 Nov 2023 10:06:49 -0600 Subject: [PATCH 2/7] [TIR] Handle specialization that remaps a buffer var --- src/tir/ir/specialize.cc | 7 +++-- tests/python/tir-base/test_tir_specialize.py | 30 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 9ccc5cecd027..8a4ab5edb9af 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -216,17 +216,19 @@ class PrimFuncSpecializer : public StmtExprMutator { private: Buffer MutateBuffer(const Buffer& buffer) { + Var data = VisitExpr(buffer->data).as().value_or(buffer->data); Array shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); }); Array strides = buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); }); PrimExpr elem_offset = VisitExpr(buffer->elem_offset); - if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) && - buffer->strides.same_as(strides)) { + if (buffer->data.same_as(data) && buffer->elem_offset.same_as(elem_offset) && + buffer->shape.same_as(shape) && buffer->strides.same_as(strides)) { return buffer; } else { auto n = make_object(*buffer.get()); + n->data = std::move(data); n->elem_offset = std::move(elem_offset); n->shape = std::move(shape); n->strides = std::move(strides); @@ -348,6 +350,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer << " vs. " << specific_buf->strides.size() << "."; // Updating var mapping using specific_expr + build_var_mapping(specific_buf->data, buf_to_specialize->data); for (size_t i = 0; i < specific_buf->shape.size(); ++i) { build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]); } diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index d98c76d140f7..05df7adf4cfb 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -264,5 +264,35 @@ def expected(A_data: T.handle("float32")): tvm.ir.assert_structural_equal(expected, after) +def test_specialize_buffer_var_to_var(): + """A buffer var may be remapped by specialization + + If a buffer variable is replaced by a specialization, then other + buffers using the same buffer var should also be updated. + """ + + @T.prim_func(private=True) + def before(A: T.Buffer([16, 16], "float32"), B: T.Buffer([16, 16], "float32")): + A_flat = T.decl_buffer([256], "float32", data=A.data) + B_flat = T.decl_buffer([256], "float32", data=B.data) + for i in range(256): + B_flat[i] = A_flat[i] * 2.0 + + @T.prim_func(private=True) + def expected(A: T.Buffer([16, 16], "float32"), B_handle: T.handle): + B = T.match_buffer(B_handle, [16, 16], "float32", data=A.data) + A_flat = T.decl_buffer([256], "float32", data=A.data) + B_flat = T.decl_buffer([256], "float32", data=A.data) + for i in range(256): + B_flat[i] = A_flat[i] * 2.0 + + A = before.buffer_map[before.params[0]] + B_handle = before.params[1] + param_map = {B_handle: A} + after = before.specialize(param_map) + + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main() From b4841421861b7fcb17636eb31f07dcd215bb46ac Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 14 Nov 2023 10:19:16 -0600 Subject: [PATCH 3/7] [TIR] Handle specialization of buffer variable to PrimExpr --- src/tir/ir/specialize.cc | 58 ++++++++++---- tests/python/tir-base/test_tir_specialize.py | 80 ++++++++++++++------ 2 files changed, 97 insertions(+), 41 deletions(-) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 8a4ab5edb9af..d4fc5478a9c5 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -29,6 +29,7 @@ #include +#include "../transforms/ir_utils.h" #include "functor_common.h" namespace tvm { @@ -115,18 +116,18 @@ class PrimFuncSpecializer : public StmtExprMutator { private: Stmt VisitStmt_(const BlockNode* op) final { // Step.0. Define buffer mappings which is allocated inside the block - Array alloc_buffers = op->alloc_buffers.Map( - std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1)); + Array alloc_buffers = + op->alloc_buffers.Map([this](const auto& buf) { return MutateAllocBuffer(buf); }); // Step.1. Recursively visit block body Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - Array reads = op->reads.Map( - std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); - Array writes = op->writes.Map( - std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); + Array reads = + op->reads.Map([this](const auto& region) { return MutateBufferRegion(region); }); + Array writes = + op->writes.Map([this](const auto& region) { return MutateBufferRegion(region); }); if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) && writes.same_as(op->writes)) { @@ -141,19 +142,40 @@ class PrimFuncSpecializer : public StmtExprMutator { } Stmt VisitStmt_(const DeclBufferNode* op) final { - auto new_buf = MutateAllocBuffer(op->buffer); + // Visit the buffer before delegating to StmtExprMutator, so the + // buffer's replacement will be defined before the point of use. + Var old_buffer_var = op->buffer->data; + Buffer new_buf = MutateAllocBuffer(op->buffer); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op != nullptr); + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - if (new_buf.same_as(op->buffer)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->buffer = new_buf; - return Stmt(n); + if (!new_buf.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = new_buf; + } + + // If the buffer variable is begin remapped to an expression, we + // still need a tir::Var to be used as a the buffer variable. + // Therefore, generate a LetStmt that will provide a tir::Var for + // the buffer to use. + // + // This step is only required when a buffer definition is using a + // previously-defined buffer variable, which is therefore eligible + // for specialization. An allocation in the + // `BlockNode::alloc_buffers` defines both the buffer variable and + // the buffer, this check is unnecessary there. In addition, if + // the buffer var has been remapped to another variable, it has already + // been handled as part of the buffer mutation. + Var new_buffer_var = node->buffer->data; + Stmt stmt = std::move(node); + + if (new_buffer_var.same_as(old_buffer_var)) { + auto remapped_data = VisitExpr(old_buffer_var); + if (!remapped_data.same_as(old_buffer_var)) { + stmt = LetStmt(old_buffer_var, remapped_data, stmt); + } } + + return stmt; } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -216,7 +238,11 @@ class PrimFuncSpecializer : public StmtExprMutator { private: Buffer MutateBuffer(const Buffer& buffer) { + // For the data variable, only Var-to-Var remapping can be handled + // in MutateBuffer. See the DeclBuffer visitor for the handling + // of Var-to-PrimExpr remapping. Var data = VisitExpr(buffer->data).as().value_or(buffer->data); + Array shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); }); Array strides = buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); }); diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index 05df7adf4cfb..f695b8522594 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -169,28 +169,6 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int3 B[vi, vj] = A[vi, vj] -@T.prim_func -def param_in_arith_exprs(a: T.handle, b: T.handle) -> None: - n = T.int32() - A = T.match_buffer(a, [n // 8, 8], "int32") - B = T.match_buffer(b, [n], "int32") - for i in range(n - 1): - with T.block(): - vi = T.axis.S(n - 1, i) - B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 - - -@T.prim_func -def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None: - n = T.int32() - A = T.match_buffer(a, [2, 8], "int32") - B = T.match_buffer(b, [16], "int32") - for i in range(15): - with T.block(): - vi = T.axis.S(15, i) - B[vi] = A[vi // 8, vi % 8] + 714 - - def test_specialize_nothing(): func = matmul.specialize({}) assert func.same_as(matmul) # Pointer the same @@ -238,9 +216,28 @@ def test_specialize_recursive_load(): def test_specialize_with_const_folding(): - b = param_in_arith_exprs.params[1] - func = param_in_arith_exprs.specialize({b: tvm.tir.decl_buffer([16])}) - assert_structural_equal_ignore_global_symbol(func, param_in_arith_exprs_n_16) + @T.prim_func + def before(a: T.handle, b: T.handle): + n = T.int32() + A = T.match_buffer(a, [n // 8, 8], "int32") + B = T.match_buffer(b, [n], "int32") + for i in range(n - 1): + with T.block(): + vi = T.axis.S(n - 1, i) + B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, [2, 8], "int32") + B = T.match_buffer(b, [16], "int32") + for i in range(15): + with T.block(): + vi = T.axis.S(15, i) + B[vi] = A[vi // 8, vi % 8] + 714 + + b = before.params[1] + after = before.specialize({b: tvm.tir.decl_buffer([16], dtype="int32")}) + assert_structural_equal_ignore_global_symbol(expected, after) def test_specialize_decl_buffer(): @@ -294,5 +291,38 @@ def expected(A: T.Buffer([16, 16], "float32"), B_handle: T.handle): tvm.ir.assert_structural_equal(expected, after) +def test_specialize_buffer_var_to_expr(): + """Handle specialization of buffer var + + The `tir::Buffer::data` field must be an explicit `tir::Var`, and + cannot be replaced with a `tir::PrimExpr` of type + `DataType::Handle()`. However, these substitutions are useful + when lowering. If these occur, a binding of the `tir::Var` is + included in the specialized function. + """ + + @T.prim_func(private=True) + def before(A_data: T.handle("float32"), B_data: T.handle("float32")): + A_buf = T.decl_buffer(32, "float32", data=A_data) + B_buf = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B_buf[i] = A_buf[i] * 2.0 + + @T.prim_func(private=True) + def expected(A_data: T.handle("float32")): + A_buf = T.decl_buffer(32, "float32", data=A_data) + B_data: T.Ptr[T.float32] = T.address_of(A_buf[16]) + B_buf = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B_buf[i] = A_buf[i] * 2.0 + + B_data = before.params[1] + A_buf = before.body.buffer + param_map = {B_data: tvm.tir.address_of(A_buf[16])} + after = before.specialize(param_map) + + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main() From f49b1f8e5154ab3f30ac43941990ab495aae036a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 14 Nov 2023 10:53:56 -0600 Subject: [PATCH 4/7] [TIR][Transform] Implement InlinePrivateFunctions The functionality to express a call from one `PrimFunc` to another was introduced in https://github.com/apache/tvm/pull/14889. While this was initially planned to be supported at codegen for all targets (see https://github.com/apache/tvm/pull/15835), this resulted in breakage on some backends (see https://github.com/apache/tvm/pull/16033). After discussion, the plan was changed to support TIR inlining, which would enable the same high-level functionality in TIR without requiring immediate low-level support across all codegens. This commit implements and tests a new IRModule transform `InlinePrivateFunctions`, which can be used as part of lowering in a follow-up commit. Because this is initially implemented for use quite late in the lowering flow, many constructs are not currently supported. The current implementation has the following restrictions. * `tir::Block` nodes may not occur in the inlined function. Because a subroutine may be called multiple times, inlining of a subroutine that contains `tir::Block` would result in non-unique names. Support of subroutines with `tir::Block` instances will require de-duplication of block names. * The subroutine's callsite must occur within a `tir::Evaluate` block. Because inlining a subroutine inserts the `tir::Stmt` body at the point of use, replacement must occur in a context where a `tir::Stmt` can be returned. Support of subroutines that are called within an expression (e.g. Replacing `func` in `Buf[0] = func(1) + func(2)`) would require hoisting preprocessing done in the subroutine to the parent `tir::Stmt`. * The subroutine may only accept primitive arguments, and must have an empty `buffer_map`. Support of subroutines that are called with `tir::Buffer` or `tir::BufferRegion` arguments would require a way to represent these arguments at the callsite, and substitution of the buffer into the callee. If these unsupported constructs are used, then the inlining of those functions is skipped. This commit includes unit tests for these unsupported constructs, to validate that `InlinePrivateFunctions` produces well-formed output even when they are present. --- include/tvm/tir/transform.h | 7 + python/tvm/tir/transform/transform.py | 11 + .../transforms/inline_private_functions.cc | 273 ++++++++++++++++++ .../test_tir_inline_private_functions.py | 253 ++++++++++++++++ 4 files changed, 544 insertions(+) create mode 100644 src/tir/transforms/inline_private_functions.cc create mode 100644 tests/python/tir-transform/test_tir_inline_private_functions.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index a1697d807db9..76826fdf7c5a 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -414,6 +414,13 @@ TVM_DLL Pass BF16StorageLegalize(); */ TVM_DLL Pass FP8StorageLegalize(); +/*! + * \brief Inline calls to private functions + * + * \return The pass. + */ +TVM_DLL Pass InlinePrivateFunctions(); + /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index a46b2d10373f..42c9aecd18e7 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -230,6 +230,17 @@ def StorageRewrite(): return _ffi_api.StorageRewrite() # type: ignore +def InlinePrivateFunctions(): + """Inline calls to private functions + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InlinePrivateFunctions() # type: ignore + + def PointerValueTypeRewrite(): """ Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc new file mode 100644 index 000000000000..a47c852067fa --- /dev/null +++ b/src/tir/transforms/inline_private_functions.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inline_private_functions.cc + * \brief Inline private functions to their callsite + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace transform { + +namespace { + +template +using PSet = std::unordered_set; + +template +using PMap = std::unordered_map; + +PMap> CollectCallMap(const IRModule& mod) { + struct Visitor : StmtExprVisitor { + GlobalVar current; + PMap> caller_lookup; + + void VisitExpr_(const CallNode* op) { + if (auto gvar = op->op.as()) { + caller_lookup[gvar.value()].insert(current); + } + StmtExprVisitor::VisitExpr_(op); + } + } visitor; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto prim_func = base_func.as()) { + visitor.current = gvar; + visitor(prim_func->body); + } + } + + return visitor.caller_lookup; +} + +PSet CollectRecursiveFunctions(const IRModule& mod) { + // Collect all direct callers + auto call_map = CollectCallMap(mod); + + // Propagate to find all indirect callers + while (true) { + bool made_change = false; + for (const auto& [callee, callers] : call_map) { + for (const auto& caller : callers) { + if (auto it = call_map.find(caller); it != call_map.end()) { + PSet& indirect_callers = it->second; + + auto res = indirect_callers.insert(callee); + made_change = made_change || res.second; + } + } + } + if (!made_change) { + break; + } + } + + // Filter all GlobalVars that can be called by themselves, either + // directly or indirectly. + PSet recursive_funcs; + for (const auto& [caller, callees] : call_map) { + if (callees.count(caller)) { + recursive_funcs.insert(caller); + } + } + return recursive_funcs; +} + +Map CollectInlinablePrimFuncs(const IRModule& mod) { + auto recursive_functions = CollectRecursiveFunctions(mod); + + Map output; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + auto prim_func = opt.value(); + + // Only inline private functions. Externally-exposed functions + // must be preserved so to avoid breaking callsites outside of + // the IRModule. + bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + + // We do not currently implement any analysis for termination of + // a function. If a recursive function requires runtime checks + // in order to terminate, we would keep inlining until the + // recursive visits segfault. + bool is_recursive = recursive_functions.count(gvar); + + // We do not currently support inlining of functions that accept + // buffer arguments. + bool has_buffer_arguments = prim_func->buffer_map.size(); + + // We do not currently support inlining of schedulable TIR + // functions. To support this use case, repeated names in + // `tir::Block` nodes resulting from multiple calls to the same + // inlined function will need to be de-duplicated. + bool has_block_node = prim_func->body.as(); + + if (!is_exposed && !is_recursive && !has_buffer_arguments && !has_block_node) { + output.Set(gvar, prim_func); + } + } + } + + return output; +} + +class PrimFuncInliner : StmtExprMutator { + public: + explicit PrimFuncInliner(Map inlinable_funcs) + : inlinable_funcs_(inlinable_funcs) { + for (const auto& [gvar, callee] : inlinable_funcs_) { + removable_funcs_.insert(gvar); + } + } + + PrimFunc VisitFunc(PrimFunc func) { + current_target_ = func->GetAttr(tvm::attr::kTarget); + auto new_body = VisitStmt(func->body); + current_target_ = NullOpt; + + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } + + return func; + } + + PSet GetRemovableFunctions() const { return removable_funcs_; } + + private: + Stmt VisitStmt_(const EvaluateNode* eval) override { + if (auto call = eval->value.as()) { + if (auto gvar = call->op.as()) { + if (auto opt_callee = inlinable_funcs_.Get(gvar.value())) { + auto callee = opt_callee.value(); + + bool is_same_target = [&]() -> bool { + auto callee_target = callee->GetAttr(tvm::attr::kTarget); + if (current_target_ && callee_target) { + return callee_target.value()->str() == current_target_.value()->str(); + } else { + return true; + } + }(); + + if (is_same_target) { + Stmt inlined = InlineArguments(gvar.value(), callee, call->args); + return VisitStmt(inlined); + } + } + } + } + + return StmtExprMutator::VisitStmt_(eval); + } + + PrimExpr VisitExpr_(const CallNode* call) override { + // Any callee that hasn't been inlined at this point must be kept + // in the output IRModule. + if (auto gvar = call->op.as()) { + removable_funcs_.erase(gvar.value()); + } + return StmtExprMutator::VisitExpr_(call); + } + + Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, const Array& args) const { + CHECK_EQ(callee->params.size(), args.size()) + << "Callee " << gvar << " accepts " << callee->params.size() << " parameters (" + << callee->params << "), but is called with " << args.size() << " arguments (" << args + << ")"; + + ICHECK(callee->buffer_map.empty()) + << "Inlining of PrimFuncs with buffer arguments is not yet supported, " + << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; + + Map param_map; + for (size_t i = 0; i < callee->params.size(); i++) { + param_map.Set(callee->params[i], args[i]); + } + + callee = Specialize(callee, param_map); + + return callee->body; + } + + // Map from GlobalVar to PrimFuncs which may be inlined. + Map inlinable_funcs_; + + /* \brief Set of callees that may be removed + * + * Some constructs may not be inlined (e.g. if the call site occurs + * outside of an Evaluate node). For these cases, the output + * IRModule must still contain the callee. + */ + PSet removable_funcs_; + + Optional current_target_ = NullOpt; +}; + +} // namespace + +Pass InlinePrivateFunctions() { + auto pass_func = [](IRModule mod, PassContext ctx) { + auto inlinable_prim_funcs = CollectInlinablePrimFuncs(mod); + + if (inlinable_prim_funcs.empty()) { + // Early bail-out if the module has no inlinable PrimFuncs. + return mod; + } + + PrimFuncInliner mutator(std::move(inlinable_prim_funcs)); + IRModule updates; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + auto updated = mutator.VisitFunc(opt.value()); + if (!updated.same_as(base_func)) { + updates->Add(gvar, updated); + } + } + } + + if (updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + write_ptr->Update(updates); + for (const auto& gvar : mutator.GetRemovableFunctions()) { + write_ptr->Remove(gvar); + } + mod = ConvertSSA()(mod); + } + + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions").set_body_typed(InlinePrivateFunctions); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py b/tests/python/tir-transform/test_tir_inline_private_functions.py new file mode 100644 index 000000000000..2edf74ebfb3d --- /dev/null +++ b/tests/python/tir-transform/test_tir_inline_private_functions.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm.testing +from tvm.script import ir as I, tir as T + + +class BaseTestCase: + def test_well_formed(self): + After = tvm.tir.transform.InlinePrivateFunctions()(self.Before) + tvm.tir.analysis.verify_well_formed(After) + + def test_produces_expected(self): + After = tvm.tir.transform.InlinePrivateFunctions()(self.Before) + tvm.ir.assert_structural_equal(self.Expected, After) + + +class TestSimple(BaseTestCase): + """Simple case directly acting on PrimFunc""" + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + for i in range(64): + Before.subroutine(T.address_of(A[i, 0]), T.address_of(B[i, 0])) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): + A = T.decl_buffer([16, 16], "float32", data=A_data) + B = T.decl_buffer([16], "float32", data=B_data) + for i in range(16): + B[i] = 0.0 + for j in range(16): + B[i] = B[i] + A[i, j] + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + for i in range(64): + A_view_data: T.handle("float32") = T.address_of(A[i, 0]) + Aview = T.decl_buffer([16, 16], "float32", data=A_view_data) + B_view_data: T.handle("float32") = T.address_of(B[i, 0]) + Bview = T.decl_buffer([16], "float32", data=B_view_data) + for j in range(16): + Bview[j] = 0.0 + for k in range(16): + Bview[j] = Bview[j] + Aview[j, k] + + +class TestRetainCrossFunctionSubroutines(BaseTestCase): + """Do not inline functions that cross device boundaries + + When lowering TIR, calls for which the callsite and callee have + different targets are used at some stages, before being further + lowered to explicit device kernel launches. Since inlining the + function would remove this cross-device information, + InlinePrivateSubroutines should not inline these cases. + """ + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + T.func_attr({"target": T.target("llvm")}) + for i in range(64): + Before.subroutine(T.address_of(A[i, 0]), T.address_of(B[i, 0])) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + A = T.decl_buffer([16, 16], "float32", data=A_data) + B = T.decl_buffer([16], "float32", data=B_data) + for i in range(16): + B[i] = 0.0 + for j in range(16): + B[i] = B[i] + A[i, j] + + Expected = Before + + +class TestRetainRecursiveSubroutines(BaseTestCase): + """Do not inline recursive functions + + To avoid potentially infinite loops at compile-time, disable + inlining of recursive functions. If inlining of these functions + would be useful, this restriction may be relaxed with improved + analysis of the subroutine. + """ + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + Before.subroutine(T.address_of(A[0]), 16) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), A_size: T.int32): + A = T.decl_buffer(A_size, "float32", data=A_data) + A[1] = A[0] + A[1] + + if A_size > 1: + Before.subroutine(T.address_of(A[1]), A_size - 1) + + Expected = Before + + +class TestDeduplicateBlockName(BaseTestCase): + """Block names must be de-duplicated after inlining""" + + @pytest.mark.xfail(reason="Inlining of schedulable TIR not yet supported") + def test_produces_expected(self): + super().test_produces_expected(self) + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer([2, 16], "float32"), B: T.Buffer([2, 16], "float32")): + Before.subroutine(T.address_of(A[0, 0]), T.address_of(B[0, 0])) + Before.subroutine(T.address_of(A[1, 0]), T.address_of(B[1, 0])) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + with T.block("scalar_mul"): + B[i] = A[i] * 2.0 + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + with T.LetStmt(T.address_of(A[0, 0]), var=T.handle("float32")) as A_data_1: + A_1 = T.decl_buffer(16, "float32", data=A_data_1) + B_data_1: T.handle("float32") = T.address_of(B[0, 0]) + B_1 = T.decl_buffer(16, "float32", data=B_data_1) + for i in range(16): + with T.block("scalar_mul_1"): + B_1[i] = A_1[i] * 2.0 + + with T.LetStmt(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2: + A_2 = T.decl_buffer(16, "float32", data=A_data_2) + B_data_2: T.handle("float32") = T.address_of(B[1, 0]) + B_2 = T.decl_buffer(16, "float32", data=B_data_2) + for i in range(16): + with T.block("scalar_mul_2"): + B_2[i] = A_2[i] * 2.0 + + +class TestInlineCallOccurringInExpression(BaseTestCase): + """Inline a Call node that is used in a function + + The current implementation only replaces `tir.Call` instances that + occur in a `tir.Evaluate` context. This is the primary use case, + used in destination-passing style. + + This unit test is marked as xfail. If/when the implementation + supports inlining of function calls occurring as part of an + expression, the annotation can be removed. + """ + + @pytest.mark.xfail(reason="Inlining of PrimFuncs outside of tir.Evaluate is not yet supported") + def test_produces_expected(self): + super().test_produces_expected(self) + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = Before.subroutine(i) + + @T.prim_func(private=True) + def subroutine(i: T.int32) -> T.float32: + cos = T.cos(T.cast(i, "float32")) + sin = T.sin(T.cast(i, "float32")) + retval = cos * cos + sin * sin + T.ret(retval) + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + for i in range(16): + cos = T.cos(T.cast(i, "float32")) + sin = T.sin(T.cast(i, "float32")) + retval = cos * cos + sin * sin + A[i] = retval + + +class TestInlineFunctionWithBufferArguments(BaseTestCase): + """Inline a function that accepts buffer arguments + + The current implementation does not support this usage. This unit + test is provided to display a possible user interaction, and is + marked with `@pytest.mark.xfail`. If/when the implementation + supports inlining of function calls with buffer arguments, the + annotation can be removed. + """ + + @pytest.mark.xfail(reason="Inlining of PrimFuncs with buffer arguments") + def test_produces_expected(self): + super().test_produces_expected(self) + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + Before.subroutine( + T.tvm_stack_make_array( + A.data, + T.tvm_stack_make_shape(*A.shape, dtype="handle"), + 0, + len(A.shape), + 0.0, + A.elem_offset, + dtype="handle", + ) + ) + + @T.prim_func(private=True) + def subroutine(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = A[i] * 2.0 + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = A[i] * 2.0 + + +if __name__ == "__main__": + tvm.testing.main() From 4d74f521b5dfe0437d6375cb7a339757ba9aad79 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Dec 2023 20:56:38 -0600 Subject: [PATCH 5/7] Updates based on review comments --- src/tir/ir/specialize.cc | 2 +- .../transforms/inline_private_functions.cc | 115 +++++++++++------- 2 files changed, 72 insertions(+), 45 deletions(-) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index d4fc5478a9c5..5964f0293299 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -153,7 +153,7 @@ class PrimFuncSpecializer : public StmtExprMutator { node.CopyOnWrite()->buffer = new_buf; } - // If the buffer variable is begin remapped to an expression, we + // If the buffer variable is being remapped to an expression, we // still need a tir::Var to be used as a the buffer variable. // Therefore, generate a LetStmt that will provide a tir::Var for // the buffer to use. diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index a47c852067fa..cc33ba9f86c2 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -97,6 +97,36 @@ PSet CollectRecursiveFunctions(const IRModule& mod) { return recursive_funcs; } +bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func, + const PSet& recursive_functions) { + // Only inline private functions. Externally-exposed functions + // must be preserved so to avoid breaking callsites outside of + // the IRModule. + bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + if (is_exposed) return false; + + // We do not currently implement any analysis for termination of + // a function. If a recursive function requires runtime checks + // in order to terminate, we would keep inlining until the + // recursive visits segfault. + bool is_recursive = recursive_functions.count(gvar); + if (is_recursive) return false; + + // We do not currently support inlining of functions that accept + // buffer arguments. + bool has_buffer_arguments = prim_func->buffer_map.size(); + if (has_buffer_arguments) return false; + + // We do not currently support inlining of schedulable TIR + // functions. To support this use case, repeated names in + // `tir::Block` nodes resulting from multiple calls to the same + // inlined function will need to be de-duplicated. + bool has_block_node = prim_func->body.as(); + if (has_block_node) return false; + + return true; +} + Map CollectInlinablePrimFuncs(const IRModule& mod) { auto recursive_functions = CollectRecursiveFunctions(mod); @@ -104,29 +134,7 @@ Map CollectInlinablePrimFuncs(const IRModule& mod) { for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto prim_func = opt.value(); - - // Only inline private functions. Externally-exposed functions - // must be preserved so to avoid breaking callsites outside of - // the IRModule. - bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); - - // We do not currently implement any analysis for termination of - // a function. If a recursive function requires runtime checks - // in order to terminate, we would keep inlining until the - // recursive visits segfault. - bool is_recursive = recursive_functions.count(gvar); - - // We do not currently support inlining of functions that accept - // buffer arguments. - bool has_buffer_arguments = prim_func->buffer_map.size(); - - // We do not currently support inlining of schedulable TIR - // functions. To support this use case, repeated names in - // `tir::Block` nodes resulting from multiple calls to the same - // inlined function will need to be de-duplicated. - bool has_block_node = prim_func->body.as(); - - if (!is_exposed && !is_recursive && !has_buffer_arguments && !has_block_node) { + if (IsInlinablePrimFunc(gvar, prim_func, recursive_functions)) { output.Set(gvar, prim_func); } } @@ -160,32 +168,51 @@ class PrimFuncInliner : StmtExprMutator { private: Stmt VisitStmt_(const EvaluateNode* eval) override { - if (auto call = eval->value.as()) { - if (auto gvar = call->op.as()) { - if (auto opt_callee = inlinable_funcs_.Get(gvar.value())) { - auto callee = opt_callee.value(); - - bool is_same_target = [&]() -> bool { - auto callee_target = callee->GetAttr(tvm::attr::kTarget); - if (current_target_ && callee_target) { - return callee_target.value()->str() == current_target_.value()->str(); - } else { - return true; - } - }(); - - if (is_same_target) { - Stmt inlined = InlineArguments(gvar.value(), callee, call->args); - return VisitStmt(inlined); - } - } - } + if (auto inlined = GetInlinedFunction(eval)) { + return inlined.value(); + } else { + return StmtExprMutator::VisitStmt_(eval); } + } + + Optional GetInlinedFunction(const EvaluateNode* eval) { + auto call = eval->value.as(); + if (!call) return NullOpt; + + auto gvar = call->op.as(); + if (!gvar) return NullOpt; + + auto opt_callee = inlinable_funcs_.Get(gvar.value()); + if (!opt_callee) return NullOpt; + auto callee = opt_callee.value(); + + bool is_same_target = [&]() -> bool { + auto callee_target = callee->GetAttr(tvm::attr::kTarget); + if (current_target_ && callee_target) { + return callee_target.value()->str() == current_target_.value()->str(); + } else { + return true; + } + }(); + if (!is_same_target) return NullOpt; - return StmtExprMutator::VisitStmt_(eval); + Stmt inlined = InlineArguments(gvar.value(), callee, call->args); + return VisitStmt(inlined); } PrimExpr VisitExpr_(const CallNode* call) override { + // Because the current implementation inlines a subroutine inserts + // the `tir::Stmt` body at the point of use, replacement must + // occur in a context where a `tir::Stmt` can be returned. Support + // of subroutines that are called within an expression + // (e.g. Replacing func in `Buf[0] = func(1) + func(2)`) would + // require hoisting preprocessing done in the subroutine to the + // parent `tir::Stmt`. + // + // See `TestInlineCallOccurringInExpression` in + // `test_tir_inline_private_functions.py` for a test of this + // behavior, currently marked with `pytest.mark.xfail`. + // // Any callee that hasn't been inlined at this point must be kept // in the output IRModule. if (auto gvar = call->op.as()) { From 4ca4c1121ad17e79ee3f90ffff41873a097f086a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 29 Dec 2023 10:11:43 -0600 Subject: [PATCH 6/7] ci bump From efe7bbd779694ed3c5b2f1ad5d5c815a18215b1c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 2 Jan 2024 14:26:08 -0600 Subject: [PATCH 7/7] CI bump