Skip to content
7 changes: 7 additions & 0 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ class BufferNode : public Object {
*/
PrimExpr ElemOffset(Array<PrimExpr> index) const;

/*! \brief Return number of elements in the buffer
*
* If the size of the buffer isn't constant, or if the size would
* overflow a 32-bit signed integer, return 0.
*/
int32_t NumElements() const;

static constexpr const char* _type_key = "tir.Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand Down
20 changes: 10 additions & 10 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,8 @@ class AllocateNode : public StmtNode {
Var buffer_var;
/*! \brief The type of the buffer. */
DataType dtype;
/*! \brief The extents of the buffer. */
Array<PrimExpr> extents;
/*! \brief The extent of the buffer. */
PrimExpr extent;
/*! \brief Only allocate buffer when condition is satisfied. */
PrimExpr condition;
/*! \brief The body to be executed. */
Expand All @@ -532,7 +532,7 @@ class AllocateNode : public StmtNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("extent", &extent);
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("annotations", &annotations);
Expand All @@ -541,14 +541,14 @@ class AllocateNode : public StmtNode {

bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
equal(extents, other->extents) && equal(condition, other->condition) &&
equal(extent, other->extent) && equal(condition, other->condition) &&
equal(body, other->body) && equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(buffer_var);
hash_reduce(dtype);
hash_reduce(extents);
hash_reduce(extent);
hash_reduce(condition);
hash_reduce(body);
hash_reduce(annotations);
Expand All @@ -559,14 +559,14 @@ class AllocateNode : public StmtNode {
* Otherwise return 0.
* \return The result.
*/
int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
int32_t constant_allocation_size() const { return constant_allocation_size(extent); }
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \param extents The extents of the buffer.
* \param extent The extent of the buffer.
* \return The result.
*/
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
TVM_DLL static int32_t constant_allocation_size(const PrimExpr& extent);

static constexpr const char* _type_key = "tir.Allocate";
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
Expand All @@ -578,8 +578,8 @@ class AllocateNode : public StmtNode {
*/
class Allocate : public Stmt {
public:
TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
TVM_DLL Allocate(Var buffer_var, DataType dtype, PrimExpr extent, PrimExpr condition, Stmt body,
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
Expand Down
13 changes: 7 additions & 6 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def ReplaceOperators():
pointer_to_producer = {}
pointer_to_consumer = {}
replace_output_pointer = {}
pointer_to_extents = {}
pointer_to_extent = {}

def _resolve_pointers(stmt):
"""This pass determines information about the pointers present in the IR.
Expand All @@ -75,7 +75,7 @@ def _get_loads(stmt):
loads.append(stmt.buffer_var)

if isinstance(stmt, tvm.tir.Allocate):
pointer_to_extents[stmt.buffer_var] = stmt.extents
pointer_to_extent[stmt.buffer_var] = stmt.extent
if isinstance(stmt.body[0], tvm.tir.AttrStmt):
if stmt.body[0].attr_key == "pragma_op":
pointer_to_producer[stmt.buffer_var] = stmt.body[0]
Expand Down Expand Up @@ -160,7 +160,7 @@ def _replace_pointers(stmt):
# If the pointer doesn't have an extent registered to it,
# this means the pointer is to a Buffer. In this case, we
# just want to delete the memory scope attribute
if replace_pointer not in pointer_to_extents:
if replace_pointer not in pointer_to_extent:
return stmt.body
# Otherwise, rewrite the memory scope attribute with the new pointer
return tvm.tir.AttrStmt(
Expand All @@ -174,12 +174,12 @@ def _replace_pointers(stmt):
# If the pointer doesn't have an extent registered to it,
# this means the pointer is to a Buffer. In this case, we
# just want to delete the allocation statement
if replace_pointer not in pointer_to_extents:
if replace_pointer not in pointer_to_extent:
return stmt.body
# Otherwise, rewrite the allocation statement with the new pointer
# and the new extent
replace_type = replace_pointer.type_annotation.element_type.dtype
replace_extents = pointer_to_extents[replace_pointer]
replace_extents = pointer_to_extent[replace_pointer]
return tvm.tir.Allocate(
replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body
)
Expand Down Expand Up @@ -404,10 +404,11 @@ def _visit_rewrite(stmt):
if pointer_to_buffer[allocate_pointer] in rewrite_buffer:
new_buffer = rewrite_buffer[pointer_to_buffer[allocate_pointer]]
new_pointer = rewrite_pointer[allocate_pointer]
assert len(new_buffer.shape) == 1
return tvm.tir.Allocate(
new_pointer,
new_buffer.dtype,
new_buffer.shape,
new_buffer.shape[0],
stmt.condition,
stmt.body,
stmt.span,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def populate_allocate_buffer_info(stmt):
allocate = stmt
buffer_info[allocate.buffer_var] = BufferInfo(
None,
allocate.extents,
[allocate.extent],
allocate.dtype,
BufferType.scratch,
)
Expand Down
Loading