Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,13 @@ TVM_DLL Pass BF16StorageLegalize();
*/
TVM_DLL Pass FP8StorageLegalize();

/*!
 * \brief Inline calls to private functions
 *
 * Takes no arguments. The Python wrapper
 * `tvm.tir.transform.InlinePrivateFunctions` calls the FFI function
 * of the same name.
 *
 * NOTE(review): the criteria that make a function "private", and
 * which call sites are eligible for inlining, are defined by the pass
 * implementation and are not visible from this declaration -- confirm
 * there before relying on specifics.
 *
 * \return The pass.
 */
TVM_DLL Pass InlinePrivateFunctions();

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,17 @@ def StorageRewrite():
return _ffi_api.StorageRewrite() # type: ignore


def InlinePrivateFunctions():
    """Inline calls to private functions

    Thin wrapper that takes no arguments and forwards directly to the
    FFI function ``_ffi_api.InlinePrivateFunctions``.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.InlinePrivateFunctions() # type: ignore


def PointerValueTypeRewrite():
"""
Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use
Expand Down
106 changes: 86 additions & 20 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <functional>

#include "../transforms/ir_utils.h"
#include "functor_common.h"

namespace tvm {
Expand Down Expand Up @@ -115,18 +116,18 @@ class PrimFuncSpecializer : public StmtExprMutator {
private:
Stmt VisitStmt_(const BlockNode* op) final {
// Step.0. Define buffer mappings which is allocated inside the block
Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
Array<Buffer> alloc_buffers =
op->alloc_buffers.Map([this](const auto& buf) { return MutateAllocBuffer(buf); });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much cleaner this way :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I always need to pause when encountering std::placeholders, and try to replace it when reasonable to do so.


// Step.1. Recursively visit block body
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<BlockNode>();
ICHECK(op != nullptr);

Array<BufferRegion> reads = op->reads.Map(
std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
Array<BufferRegion> writes = op->writes.Map(
std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
Array<BufferRegion> reads =
op->reads.Map([this](const auto& region) { return MutateBufferRegion(region); });
Array<BufferRegion> writes =
op->writes.Map([this](const auto& region) { return MutateBufferRegion(region); });

if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
writes.same_as(op->writes)) {
Expand All @@ -140,16 +141,54 @@ class PrimFuncSpecializer : public StmtExprMutator {
}
}

/*!
 * \brief Specialize a DeclBuffer statement.
 *
 * The declared buffer is mutated, and registered as its own
 * replacement, before the statement body is visited.  That ordering
 * guarantees the replacement is already defined when uses of the
 * buffer are encountered inside the body.
 *
 * \param op The DeclBuffer node being visited.
 *
 * \return The (possibly updated) DeclBuffer.  It is wrapped in a
 * LetStmt when the buffer keeps its original data variable but that
 * variable is being specialized to some other expression.
 */
Stmt VisitStmt_(const DeclBufferNode* op) final {
  // Keep a handle on the original data variable; it is compared
  // against the data variable of the mutated buffer further down.
  const Var original_data = op->buffer->data;

  // Must happen before delegating to StmtExprMutator (see above).
  const Buffer specialized = MutateAllocBuffer(op->buffer);

  DeclBuffer decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
  if (!specialized.same_as(decl->buffer)) {
    decl.CopyOnWrite()->buffer = specialized;
  }

  // A DeclBuffer may reuse a data variable that was defined earlier,
  // which makes that variable eligible for specialization.  A
  // Var-to-Var remap has already been applied as part of the buffer
  // mutation, in which case the data variable no longer matches and
  // nothing more is needed.  (Buffers in `BlockNode::alloc_buffers`
  // define both the variable and the buffer, so they never need this
  // extra handling.)
  const bool keeps_original_data = decl->buffer->data.same_as(original_data);
  Stmt result = std::move(decl);
  if (!keeps_original_data) {
    return result;
  }

  // The data variable is unchanged, yet it may still be remapped to a
  // non-variable expression.  The buffer needs a tir::Var as its data
  // pointer, so bind the original variable to the remapped expression
  // with a LetStmt wrapped around the declaration.
  const PrimExpr bound_value = VisitExpr(original_data);
  if (bound_value.same_as(original_data)) {
    return result;
  }
  return LetStmt(original_data, bound_value, result);
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<BufferStoreNode>();
ICHECK(op != nullptr);
auto it = buffer_map_.find(op->buffer);
if (it == buffer_map_.end()) {

auto new_buf = GetNewBuffer(op->buffer);
if (new_buf.same_as(op->buffer)) {
return GetRef<BufferStore>(op);
} else {
auto n = CopyOnWrite(op);
n->buffer = it->second;
n->buffer = new_buf;
return Stmt(n);
}
}
Expand All @@ -158,12 +197,13 @@ class PrimFuncSpecializer : public StmtExprMutator {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<BufferLoadNode>();
ICHECK(op != nullptr);
auto it = buffer_map_.find(op->buffer);
if (it == buffer_map_.end()) {

auto new_buf = GetNewBuffer(op->buffer);
if (new_buf.same_as(op->buffer)) {
return GetRef<BufferLoad>(op);
} else {
auto n = make_object<BufferLoadNode>(*op);
n->buffer = it->second;
n->buffer = new_buf;
return PrimExpr(n);
}
}
Expand Down Expand Up @@ -198,17 +238,23 @@ class PrimFuncSpecializer : public StmtExprMutator {

private:
Buffer MutateBuffer(const Buffer& buffer) {
// For the data variable, only Var-to-Var remapping can be handled
// in MutateBuffer. See the DeclBuffer visitor for the handling
// of Var-to-PrimExpr remapping.
Var data = VisitExpr(buffer->data).as<Var>().value_or(buffer->data);

Array<PrimExpr> shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); });
Array<PrimExpr> strides =
buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); });

PrimExpr elem_offset = VisitExpr(buffer->elem_offset);

if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
buffer->strides.same_as(strides)) {
if (buffer->data.same_as(data) && buffer->elem_offset.same_as(elem_offset) &&
buffer->shape.same_as(shape) && buffer->strides.same_as(strides)) {
return buffer;
} else {
auto n = make_object<BufferNode>(*buffer.get());
n->data = std::move(data);
n->elem_offset = std::move(elem_offset);
n->shape = std::move(shape);
n->strides = std::move(strides);
Expand All @@ -227,14 +273,33 @@ class PrimFuncSpecializer : public StmtExprMutator {
}

Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
ICHECK(!buffer_map_.count(alloc_buf))
<< "Multiple points of definition found for buffer " << alloc_buf;

Buffer buf = MutateBuffer(alloc_buf);
if (buf.same_as(alloc_buf)) {
return alloc_buf;
} else {
ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
buffer_map_[alloc_buf] = buf;
return buf;
buffer_map_[alloc_buf] = buf;
return buf;
}

/*!
 * \brief Resolve the buffer that should replace `old_buffer` at a point of use.
 *
 * \param old_buffer The buffer referenced by the statement or
 * expression being visited.
 *
 * \return The replacement recorded in `buffer_map_` if one exists,
 * otherwise `old_buffer` itself.
 *
 * A buffer with no recorded replacement is only accepted when
 * specialization would leave it unchanged; otherwise the ICHECK below
 * fails, because a buffer may only be rewritten at its point of
 * definition.
 */
Buffer GetNewBuffer(const Buffer& old_buffer) {
  const auto entry = buffer_map_.find(old_buffer);
  if (entry != buffer_map_.end()) {
    return entry->second;
  }

  // No replacement was recorded for this buffer.  Verify that mutating
  // it here would be a no-op before handing the original back.
  const Buffer respecialized = MutateBuffer(old_buffer);
  ICHECK(respecialized.same_as(old_buffer))
      << "Buffer " << old_buffer << " (shape = " << old_buffer->shape << ")"
      << " was used without a declaration, "
      << "and would be specialized into " << respecialized
      << " (shape = " << respecialized->shape << "). "
      << "While usage of an undeclared buffer is currently allowed in TIR, "
      << "mutation must occur at the buffer's point of definition "
      << "(see discussion on https://github.com/apache/tvm/pull/14565 for more details). "
      << "Please add a definition for this buffer, "
      << "either in the PrimFunc's buffer_map, "
      << "in a tir::Block's alloc_buffer, "
      << "or in a DeclBuffer statement.";

  return old_buffer;
}

BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
Expand Down Expand Up @@ -311,6 +376,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer
<< " vs. " << specific_buf->strides.size() << ".";

// Updating var mapping using specific_expr
build_var_mapping(specific_buf->data, buf_to_specialize->data);
for (size_t i = 0; i < specific_buf->shape.size(); ++i) {
build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]);
}
Expand Down
Loading