Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,20 +521,28 @@ class AllocateNode : public StmtNode {
PrimExpr condition;
/*! \brief The body to be executed. */
Stmt body;
/*!
* \brief Additional annotations about the allocation.
*
* These annotations can be used as auxiliary hint
* to future transformations.
*/
Map<String, ObjectRef> annotations;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("annotations", &annotations);
v->Visit("span", &span);
}

bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
equal(extents, other->extents) && equal(condition, other->condition) &&
equal(body, other->body);
equal(body, other->body) && equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -543,6 +551,7 @@ class AllocateNode : public StmtNode {
hash_reduce(extents);
hash_reduce(condition);
hash_reduce(body);
hash_reduce(annotations);
}

/*!
Expand Down Expand Up @@ -570,7 +579,8 @@ class AllocateNode : public StmtNode {
class Allocate : public Stmt {
public:
TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Span span = Span());
Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};
Expand Down
16 changes: 12 additions & 4 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,20 @@ def get_optional_vars(node, context):

@register
class Allocate(WithScopeHandler):
"""With scope handler T.allocate(extents, dtype, scope, condition)"""
"""With scope handler T.allocate(extents, dtype, scope, condition, annotations)"""

def __init__(self):
def allocate(extents, dtype, scope, condition=True, span=None):
def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
condition = tvm.runtime.convert(condition)
scope = tvm.runtime.convert(scope)
return tvm.tir.Allocate(
self.buffer_var, dtype, extents, condition, self.body, span=span
self.buffer_var,
dtype,
extents,
condition,
self.body,
annotations=annotations,
span=span,
)

super().__init__(allocate, concise_scope=True, def_symbol=True)
Expand All @@ -137,7 +143,9 @@ def enter_scope(
else:
raise Exception("Internal Bug")

def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
def setup_buffer_var(
extents, dtype, scope, condition=True, annotations=None, span: Span = None
):
"""Setup buffer var for a given type."""
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
Expand Down
16 changes: 14 additions & 2 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,25 @@ class Allocate(Stmt):
body : Stmt
The body statement.

annotations: Optional[Mapping[str, Object]]
Additional annotation hints

span : Optional[Span]
The location of this itervar in the source code.
"""

def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None):
if annotations is None:
annotations = dict()
self.__init_handle_by_constructor__(
_ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span # type: ignore
_ffi_api.Allocate, # type: ignore
buffer_var,
dtype,
extents,
condition,
body,
annotations,
span,
)


Expand Down
12 changes: 10 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,16 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
Doc doc;
auto scope = GetPtrStorageScope(op->buffer_var);
doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
<< Print(op->extents) << "), storage_scope = " << scope;
doc << "allocate(" << Print(op->buffer_var) << ", ";
doc << PrintDType(op->dtype) << ", ";
doc << Print(op->extents) << "), storage_scope = " << scope;
if (!op->annotations.empty()) {
std::vector<Doc> attr_docs;
for (const auto& it : op->annotations) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
}
if (!is_one(op->condition)) {
doc << " if " << Print(op->condition);
}
Expand Down
10 changes: 10 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
doc << ", annotations={";
doc << PrintAnnotations(op->annotations);
doc << "}";
}
doc << ") as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
} else {
Expand All @@ -777,6 +782,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
doc << ", annotations={";
doc << PrintAnnotations(op->annotations);
doc << "}";
}
doc << ")" << Doc::NewLine() << PrintBody(op->body);
}
TryDeallocVar(op->buffer_var);
Expand Down
7 changes: 4 additions & 3 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

// Allocate
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Span span) {
Stmt body, Map<String, ObjectRef> annotations, Span span) {
CHECK(IsPointerType(buffer_var->type_annotation, dtype))
<< "The allocated data type (" << dtype
<< ") does not match the type annotation of the buffer " << buffer_var << " ("
Expand All @@ -354,6 +354,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, Prim
node->extents = std::move(extents);
node->condition = std::move(condition);
node->body = std::move(body);
node->annotations = std::move(annotations);
node->span = std::move(span);
data_ = std::move(node);
}
Expand All @@ -375,8 +376,8 @@ int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {

TVM_REGISTER_GLOBAL("tir.Allocate")
.set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Span span) {
return Allocate(buffer_var, type, extents, condition, body, span);
Stmt body, Map<String, ObjectRef> annotations, Span span) {
return Allocate(buffer_var, type, extents, condition, body, annotations, span);
});

TVM_REGISTER_NODE_TYPE(AllocateNode);
Expand Down
33 changes: 33 additions & 0 deletions tests/python/unittest/test_tir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,5 +473,38 @@ def test_block_blockrealize():
assert output.find("with init()") != -1


def test_tir_allocate():
dtype = "int8"
storage_scope = "global"
ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
a = te.var("buffer", ptype)
allocate = tvm.tir.Allocate(
buffer_var=a,
dtype=dtype,
extents=[2, 2],
condition=tvm.get_global_func("tir.const_true")(dtype, None),
body=tvm.tir.Evaluate(2 + 1),
annotations={
"attr1": "foo",
"attr2": "bar",
},
)
assert allocate.buffer_var == a
assert allocate.dtype == "int8"
assert list(allocate.extents) == [2, 2]
assert allocate.annotations["attr1"] == "foo"
assert allocate.annotations["attr2"] == "bar"

# make sure we can print using TIRTextPrinter
func = tvm.tir.PrimFunc([], allocate)
output = func.astext()
assert (
output.find(
'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})'
)
!= -1
)


if __name__ == "__main__":
pytest.main([__file__])
27 changes: 27 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3059,5 +3059,32 @@ def test_while_loop():
tvm.ir.assert_structural_equal(while_loop, rt_func)


# fmt: off
@T.prim_func
def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
# body
tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"})
for ax0_ax1_fused_4 in T.serial(0, 56):
for ax2_4 in T.serial(0, 56):
for ax3_init in T.serial(0, 64):
T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True)
for ax0_ax1_fused_5 in T.serial(0, 56):
for ax2_5, ax3_3 in T.grid(56, 64):
T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
# fmt: on


def test_primfunc_with_allocate_annotations():
func = primfunc_with_allocate_annotations
rt_func = tvm.script.from_source(func.script(show_meta=True))
tvm.ir.assert_structural_equal(func, rt_func, True)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))