Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,15 @@ Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
*/
TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);

/*!
* \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
* access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
* The LCA may be a For loop or a Block.
* \param func The PrimFunc to be detected.
* \return The Map from buffer to the LCA of all access to it.
*/
TVM_DLL Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
22 changes: 21 additions & 1 deletion python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# under the License.
"""Wrapping existing analysis utils."""
# pylint: disable=invalid-name

from typing import Dict
from . import _ffi_api
from ..function import PrimFunc
from .. import Buffer, Stmt


def expr_deep_equal(lhs, rhs):
Expand Down Expand Up @@ -129,3 +131,21 @@ def get_block_access_region(block, buffer_var_map):
- third: opaque regions
"""
return _ffi_api.get_block_access_region(block, buffer_var_map)


def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
"""Detect the lowest common ancestor(LCA) of buffer access, including both high-level
access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
The LCA may be a For loop or a Block.

Parameters
----------
func: tvm.tir.PrimFunc
The function to be detected.

Returns
-------
result : Dict[Buffer, Stmt]
Map from buffer to the LCA of all access to it.
"""
return _ffi_api.detect_buffer_access_lca(func) # pylint: disable=no-member
173 changes: 173 additions & 0 deletions src/tir/analysis/buffer_access_lca_detector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tir/analysis/buffer_access_lca_detector.cc
* \brief Detect the lowest common ancestor(LCA) of buffer access
*/

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include "../../support/arena.h"

namespace tvm {
namespace tir {

/*!
* \brief Detect the lowest common ancestor(LCA) position of Buffer access.
* \note Only consider BlockNode and ForNode to be the LCA nodes.
*/
class LCADetector : public StmtExprVisitor {
public:
static Map<Buffer, Stmt> Detect(const PrimFunc& func) {
LCADetector detector;
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
}
detector(func->body);
// Prepare the return
Map<Buffer, Stmt> buffer_lca;
for (const auto& kv : detector.buffer_lca_) {
buffer_lca.Set(GetRef<Buffer>(kv.first), GetRef<Stmt>(kv.second->stmt));
}
return buffer_lca;
}

private:
/*!
* \brief The AST node information for querying LCA.
* \note Only BlockNode and ForNode are considered, since they are the only statements whose
* body can be a SeqStmt (the LCA of buffer access) in TensorIR.
*/
struct ScopeInfo {
// The parent scope info
const ScopeInfo* parent_scope_info;
// The parent scope stmt node
const StmtNode* stmt;
// The scope depth in the AST
int depth;
ScopeInfo(const ScopeInfo* parent_info, const StmtNode* stmt, int depth)
: parent_scope_info(parent_info), stmt(stmt), depth(depth) {}
};

void VisitStmt_(const ForNode* op) final {
int n = ancestor_scopes_.size();
const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
ancestor_scopes_.push_back(current_scope);
StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
}

void VisitStmt_(const BlockNode* op) final {
int n = ancestor_scopes_.size();
for (const Buffer& buf : op->alloc_buffers) {
buffer_var_map_.emplace(buf->data.get(), buf.get());
}
const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
ancestor_scopes_.push_back(current_scope);
StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
}

void VisitExpr_(const BufferLoadNode* op) final {
UpdateBufferLCA(op->buffer.get());
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
UpdateBufferLCA(op->buffer.get());
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const BufferRealizeNode* op) final {
buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
StmtExprVisitor::VisitStmt_(op);
}

// Works for Load/Store and opaque access.
void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); }

// Explict to visit buffer data in Load and Store node.
void VisitExpr_(const LoadNode* op) final {
ExprVisitor::VisitExpr_(op);
VisitBufferVar(op->buffer_var.get());
}

void VisitStmt_(const StoreNode* op) final {
StmtVisitor::VisitStmt_(op);
VisitBufferVar(op->buffer_var.get());
}

void VisitBufferVar(const VarNode* op) {
auto it = buffer_var_map_.find(op);
if (it != buffer_var_map_.end()) {
UpdateBufferLCA(it->second);
}
}

void UpdateBufferLCA(const BufferNode* buffer) {
const ScopeInfo*& lca = buffer_lca_[buffer];
lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
}

static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
ICHECK(lhs || rhs);
if (lhs == nullptr) return rhs;
if (rhs == nullptr) return lhs;
while (lhs->parent_scope_info != nullptr && //
rhs->parent_scope_info != nullptr && //
lhs != rhs) {
if (lhs->depth == rhs->depth) {
lhs = lhs->parent_scope_info;
rhs = rhs->parent_scope_info;
} else if (lhs->depth < rhs->depth) {
rhs = rhs->parent_scope_info;
} else {
lhs = lhs->parent_scope_info;
}
}
if (lhs->parent_scope_info == nullptr) {
return lhs;
}
if (rhs->parent_scope_info == nullptr) {
return rhs;
}
ICHECK(lhs == rhs);
return lhs;
}

/*! \brief The ancestor scope stacks info (Block and For), initialized with Null. */
std::vector<const ScopeInfo*> ancestor_scopes_ = {nullptr};
/*! \brief The map from Buffer to its LCA ForNode/BlockNode. */
std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {};
/*! \brief The map from Buffer data to the Buffer. */
std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
/*! \brief Internal arena. */
support::Arena arena_;
};

Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); }

TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);
} // namespace tir
} // namespace tvm
107 changes: 107 additions & 0 deletions tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import tir
from tvm.script import ty


@tvm.script.tir
def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128), "float32")
B = tir.match_buffer(b, (128, 128), "float32")
C = tir.alloc_buffer((128, 128), "float32")
D = tir.alloc_buffer((128, 128), "float32")
with tir.block([128, 128]) as [i, j]:
A[i, j] = tir.float32(0)
with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
with tir.init():
for ii, jj in tir.grid(4, 4):
B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
for ii, jj in tir.grid(4, 4):
for kk in range(0, 4):
B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
for kk in range(0, 4):
B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]


@tvm.script.tir
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
B = tir.match_buffer(b, [16, 16], "float32")
C = tir.match_buffer(c, [16, 16], "float32")

with tir.block([]):
tir.reads([])
tir.writes(B[0:16, 0:16])
A = tir.allocate([256], "float32", "global")
for i, j in tir.grid(16, 16):
tir.store(A, i * 16 + j, 1)
for i in range(0, 16):
for j in range(0, 16):
tir.evaluate(tir.load("float32", A, i * 16 + j))
for j in range(0, 16):
tir.evaluate(
tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, tir.float32(0), dtype="handle")
)

for i, j in tir.grid(16, 16):
with tir.block([16, 16]) as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
C[vi, vj] = B[vi, vj]


def test_buffer_load_store():
func = buffer_load_store_func
A, B = [func.buffer_map[x] for x in func.params]
C, D = func.body.block.alloc_buffers
lca = tir.analysis.detect_buffer_access_lca(func)

# LCA of Buffer A is root
root_block = func.body.block
assert lca[A] == func.body.block

# LCA of Buffer B is reduction block
reduce_block = root_block.body[1].body.body.body.block
assert lca[B] == reduce_block

# LCA of Buffer C is the second loop kk
loop_jj = reduce_block.body.body
assert lca[C] == loop_jj

# LCA of Buffer D is loop jj
loop_kk = loop_jj.body[1]
assert lca[D] == loop_kk


def test_opaque_access():
func = buffer_opaque_access
B, C = [func.buffer_map[x] for x in func.params]
lca = tir.analysis.detect_buffer_access_lca(func)

# Cannot detect buffer A since it is define by low-level Allocate

# LCA of Buffer B is root
root_block = func.body.block
assert lca[B] == func.body.block

# LCA of Buffer C is the correspond block
assert lca[C] == root_block.body[1].body.body.block


if __name__ == "__main__":
test_buffer_load_store()
test_opaque_access()