diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index f04209d0b061..c84eda466570 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -129,6 +129,13 @@ class BufferNode : public Object {
    */
   PrimExpr ElemOffset(Array<PrimExpr> index) const;

+  /*! \brief Return number of elements in the buffer
+   *
+   * If the size of the buffer isn't constant, or if the size would
+   * overflow a 32-bit signed integer, return 0.
+   */
+  int32_t NumElements() const;
+
   static constexpr const char* _type_key = "tir.Buffer";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 5cd860b8e929..480ceebbf315 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -515,8 +515,8 @@ class AllocateNode : public StmtNode {
   Var buffer_var;
   /*! \brief The type of the buffer. */
   DataType dtype;
-  /*! \brief The extents of the buffer. */
-  Array<PrimExpr> extents;
+  /*! \brief The extent of the buffer. */
+  PrimExpr extent;
   /*! \brief Only allocate buffer when condition is satisfied. */
   PrimExpr condition;
   /*! \brief The body to be executed. */
@@ -532,7 +532,7 @@ class AllocateNode : public StmtNode {
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
     v->Visit("dtype", &dtype);
-    v->Visit("extents", &extents);
+    v->Visit("extent", &extent);
     v->Visit("condition", &condition);
     v->Visit("body", &body);
     v->Visit("annotations", &annotations);
@@ -541,14 +541,14 @@ class AllocateNode : public StmtNode {

   bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
     return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
-           equal(extents, other->extents) && equal(condition, other->condition) &&
+           equal(extent, other->extent) && equal(condition, other->condition) &&
            equal(body, other->body) && equal(annotations, other->annotations);
   }

   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce.DefHash(buffer_var);
     hash_reduce(dtype);
-    hash_reduce(extents);
+    hash_reduce(extent);
     hash_reduce(condition);
     hash_reduce(body);
     hash_reduce(annotations);
@@ -559,14 +559,14 @@ class AllocateNode : public StmtNode {
    * Otherwise return 0.
    * \return The result.
    */
-  int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
+  int32_t constant_allocation_size() const { return constant_allocation_size(extent); }

  /*!
   * \brief If the buffer size is constant, return the size.
   *        Otherwise return 0.
-  * \param extents The extents of the buffer.
+  * \param extent The extent of the buffer.
   * \return The result.
   */
-  TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
+  TVM_DLL static int32_t constant_allocation_size(const PrimExpr& extent);

   static constexpr const char* _type_key = "tir.Allocate";
   TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
@@ -578,8 +578,8 @@ class AllocateNode : public StmtNode {
  */
 class Allocate : public Stmt {
  public:
-  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
-                   Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
+  TVM_DLL Allocate(Var buffer_var, DataType dtype, PrimExpr extent, PrimExpr condition, Stmt body,
+                   Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
                    Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
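With this patch applied, an Allocate node carries a single flat extent instead of an array of per-dimension extents. A minimal sketch of constructing the node from the Python side (variable names here are illustrative, not part of the patch)::

    import tvm
    from tvm import tir

    ptr = tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
    body = tir.Evaluate(0)

    # Before this patch: tir.Allocate(ptr, "float32", [16, 32], cond, body)
    alloc = tir.Allocate(ptr, "float32", 16 * 32, tir.IntImm("uint1", 1), body)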
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 8bb410e986c7..5199d8c37579 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -58,7 +58,7 @@ def ReplaceOperators():
     pointer_to_producer = {}
     pointer_to_consumer = {}
     replace_output_pointer = {}
-    pointer_to_extents = {}
+    pointer_to_extent = {}

    def _resolve_pointers(stmt):
        """This pass determines information about the pointers present in the IR.
@@ -75,7 +75,7 @@ def _get_loads(stmt):
                loads.append(stmt.buffer_var)

        if isinstance(stmt, tvm.tir.Allocate):
-            pointer_to_extents[stmt.buffer_var] = stmt.extents
+            pointer_to_extent[stmt.buffer_var] = stmt.extent
            if isinstance(stmt.body[0], tvm.tir.AttrStmt):
                if stmt.body[0].attr_key == "pragma_op":
                    pointer_to_producer[stmt.buffer_var] = stmt.body[0]
@@ -160,7 +160,7 @@ def _replace_pointers(stmt):
            # If the pointer doesn't have an extent registered to it,
            # this means the pointer is to a Buffer. In this case, we
            # just want to delete the memory scope attribute
-            if replace_pointer not in pointer_to_extents:
+            if replace_pointer not in pointer_to_extent:
                return stmt.body
            # Otherwise, rewrite the memory scope attribute with the new pointer
            return tvm.tir.AttrStmt(
@@ -174,12 +174,12 @@ def _replace_pointers(stmt):
            # If the pointer doesn't have an extent registered to it,
            # this means the pointer is to a Buffer. In this case, we
            # just want to delete the allocation statement
-            if replace_pointer not in pointer_to_extents:
+            if replace_pointer not in pointer_to_extent:
                return stmt.body
            # Otherwise, rewrite the allocation statement with the new pointer
            # and the new extent
            replace_type = replace_pointer.type_annotation.element_type.dtype
-            replace_extents = pointer_to_extents[replace_pointer]
+            replace_extents = pointer_to_extent[replace_pointer]
            return tvm.tir.Allocate(
                replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body
            )
@@ -404,10 +404,11 @@ def _visit_rewrite(stmt):
            if pointer_to_buffer[allocate_pointer] in rewrite_buffer:
                new_buffer = rewrite_buffer[pointer_to_buffer[allocate_pointer]]
                new_pointer = rewrite_pointer[allocate_pointer]
+                assert len(new_buffer.shape) == 1
                return tvm.tir.Allocate(
                    new_pointer,
                    new_buffer.dtype,
-                    new_buffer.shape,
+                    new_buffer.shape[0],
                    stmt.condition,
                    stmt.body,
                    stmt.span,
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 408eab6427ca..72c9661a8df9 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -167,7 +167,7 @@ def populate_allocate_buffer_info(stmt):
            allocate = stmt
            buffer_info[allocate.buffer_var] = BufferInfo(
                None,
-                allocate.extents,
+                [allocate.extent],
                allocate.dtype,
                BufferType.scratch,
            )
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 978c630b17ad..19f7e669491a 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -17,8 +17,9 @@
 """Developer API of IR node builder make function."""
 from tvm._ffi.base import string_types
 from tvm.runtime import ObjectGeneric, DataType, convert, const
-from tvm.ir import container as _container, PointerType, PrimType
+from tvm.ir import container as _container, PointerType, PrimType, Range

+from . import buffer as _buffer
 from . import stmt as _stmt
 from . import expr as _expr
 from . import op
@@ -38,44 +39,45 @@ def __exit__(self, ptype, value, trace):
        self._exit_cb()


-class BufferVar(ObjectGeneric):
-    """Buffer variable with content type, makes load store easily.
+class BufferVarBuilder(ObjectGeneric):
+    """Helper to build Load/Store interactions with a buffer.

-    Do not create it directly, create use IRBuilder.
+    The BufferVarBuilder gives array access into physical memory.
+    Indices should be flat values, and are used in Load/Store nodes.

-    BufferVars support array access either via a linear index, or, if given a
-    shape, via a multidimensional index.
+    Do not create a BufferVarBuilder directly. Instead, use
+    `IRBuilder.allocate` or `IRBuilder.pointer`.

    Examples
    --------
-    In the follow example, x is BufferVar.
-    :code:`x[0] = ...` directly emit a store to the IRBuilder,
+    In the following example, x is a BufferVarBuilder.
+    :code:`x[0] = ...` directly emits a Store to the IRBuilder,
    :code:`x[10]` translates to Load.

    .. code-block:: python

-        # The following code generate IR for x[0] = x[10] + 1
        ib = tvm.tir.ir_builder.create()
+
+        # One-dimensional buffer access
        x = ib.pointer("float32")
        x[0] = x[10] + 1

-        y = ib.allocate("float32", (32, 32))
-        # Array access using a linear index
+        # Implementing multi-dimensional array access using a linear index
+        y = ib.allocate("float32", 32*32)
        y[(2*32) + 31] = 0.
-        # The same array access using a multidimensional index
-        y[2, 31] = 0.
    See Also
    --------
    IRBuilder.pointer
    IRBuilder.buffer_ptr
    IRBuilder.allocate
+    IRBuilder.buffer_realize
+
    """

-    def __init__(self, builder, buffer_var, shape, content_type):
+    def __init__(self, builder, buffer_var, content_type):
        self._builder = builder
        self._buffer_var = buffer_var
-        self._shape = shape
        self._content_type = content_type

    def asobject(self):
@@ -85,27 +87,20 @@ def dtype(self):
        return self._content_type

-    def _linear_index(self, index):
-        if not isinstance(index, tuple) or self._shape is None:
-            return index
-        assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % (
-            len(index),
-            len(self._shape),
-        )
-        dim_size = 1
-        lidx = 0
-        for dim, idx in zip(reversed(self._shape), reversed(index)):
-            lidx += idx * dim_size
-            dim_size *= dim
-        return lidx
-
-    def __getitem__(self, index):
+    def _normalize_index(self, index):
        t = DataType(self._content_type)
-        index = self._linear_index(index)
        if t.lanes > 1:
            base = index * t.lanes
            stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype)
            index = _expr.Ramp(base, stride, t.lanes)
+
+        if isinstance(index, _expr.IterVar):
+            index = index.var
+
+        return index
+
+    def __getitem__(self, index):
+        index = self._normalize_index(index)
        return _expr.Load(self._content_type, self._buffer_var, index)

    def __setitem__(self, index, value):
@@ -114,13 +109,92 @@ def __setitem__(self, index, value):
            raise ValueError(
                "data type does not match content type %s vs %s" % (value.dtype, self._content_type)
            )
-        index = self._linear_index(index)
+
+        index = self._normalize_index(index)
+        self._builder.emit(_stmt.Store(self._buffer_var, value, index))
+
+
+class BufferBuilder(ObjectGeneric):
+    """Helper to build BufferLoad/BufferStore interactions with a buffer.
+
+    The BufferBuilder gives multi-dimensional array access into
+    logical memory. Indices should have the same number of dimensions
+    as the underlying buffer. Read/writes to the BufferBuilder
+    correspond to BufferLoad/BufferStore nodes. For physical memory
+    access, see BufferVarBuilder.
+
+    Do not create a BufferBuilder directly. Instead, use
+    `IRBuilder.buffer_realize` or `IRBuilder.buffer_ptr`.
+
+    Examples
+    --------
+    In the following example, x is a BufferBuilder.
+    :code:`x[0] = ...` directly emits a BufferStore to the IRBuilder,
+    :code:`x[10]` translates to BufferLoad.
+
+    .. code-block:: python
+
+        ib = tvm.tir.ir_builder.create()
+        # One-dimensional buffer access
+        x = ib.buffer_realize("float32", 16)
+        x[0] = x[10] + 1.0
+
+        # Multi-dimensional buffer access
+        y = ib.buffer_realize("float32", (16, 32))
+        # Array access using a multidimensional index
+        y[2, 31] = 0.0
+
+    See Also
+    --------
+    IRBuilder.pointer
+    IRBuilder.buffer_ptr
+    IRBuilder.allocate
+    IRBuilder.buffer_realize
+
+    """
+
+    def __init__(self, builder, buffer, content_type):
+        self._builder = builder
+        self._buffer = buffer
+        self._content_type = content_type
+
+    def asobject(self):
+        return self._buffer
+
+    @property
+    def dtype(self):
+        return self._content_type
+
+    def _normalize_index(self, index):
+        try:
+            index = [*index]
+        except TypeError:
+            index = [index]
+
        t = DataType(self._content_type)
        if t.lanes > 1:
-            base = index * t.lanes
+            base = index[-1] * t.lanes
            stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype)
-            index = _expr.Ramp(base, stride, t.lanes)
-        self._builder.emit(_stmt.Store(self._buffer_var, value, index))
+            index[-1] = _expr.Ramp(base, stride, t.lanes)
+
+        index = [x.var if isinstance(x, _expr.IterVar) else x for x in index]
+
+        return index
+
+    def __getitem__(self, index):
+        index = self._normalize_index(index)
+        return _expr.BufferLoad(self._buffer, index)
+
+    def __setitem__(self, index, value):
+        index = self._normalize_index(index)
+
+        value = convert(value)
+        if value.dtype != self._content_type:
+            raise ValueError(
+                "data type does not match content type %s vs %s" % (value.dtype, self._content_type)
+            )
+
+        self._builder.emit(_stmt.BufferStore(self._buffer, value, index))


 class IRBuilder(object):
@@ -281,7 +355,7 @@ def while_loop(self, condition):
        .. code-block:: python

            ib = tvm.tir.ir_builder.create()
-            iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
+            iterations = ib.allocate("int32", 1, name="iterations", scope="local")
            with ib.while_loop(iterations[0] < 10):
                iterations[0] += 1
@@ -394,7 +468,7 @@ def let(self, var_name, value):
        self.emit(lambda x: _stmt.LetStmt(var, value, x))
        return var

-    def allocate(self, dtype, shape, name="buf", scope=""):
+    def allocate(self, dtype, extent, name="buf", scope=""):
        """Create a allocate statement.

        Parameters
@@ -402,8 +476,8 @@ def allocate(self, dtype, extent, name="buf", scope=""):
        dtype : str
            The content data type.

-        shape : tuple of Expr
-            The shape of array to be allocated.
+        extent : Expr
+            The number of elements to allocate.

        name : str, optional
            The name of the buffer.
@@ -417,10 +491,39 @@ def allocate(self, dtype, extent, name="buf", scope=""):
            The buffer var representing the buffer.
        """
        buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
+        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, extent, const(1, dtype="uint1"), x))
+        return BufferVarBuilder(self, buffer_var, dtype)
+
+    def buffer_realize(self, dtype, shape, name="buf", scope=""):
+        """Create a BufferRealize statement.
+
+        Parameters
+        ----------
+        dtype : str
+            The content data type.
+
+        shape : Union[Expr, List[Expr], Tuple[Expr]]
+            The shape of array to be allocated.
+
+        name : str, optional
+            The name of the buffer.
+
+        scope : str, optional
+            The scope of the buffer.
+
+        Returns
+        -------
+        buffer : BufferBuilder
+            The buffer builder representing the buffer.
+ """ + buffer = _buffer.decl_buffer(shape, dtype=dtype, name=name, scope=scope) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] - self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) - return BufferVar(self, buffer_var, shape, dtype) + + bounds = [Range(0, dim_extent) for dim_extent in shape] + + self.emit(lambda x: _stmt.BufferRealize(buffer, bounds, True, x)) + return BufferBuilder(self, buffer, dtype) def pointer(self, content_type, name="ptr", scope=""): """Create pointer variable with content type. @@ -438,14 +541,14 @@ def pointer(self, content_type, name="ptr", scope=""): Returns ------- - ptr : BufferVar + ptr : BufferVarBuilder The buffer var representing the buffer. """ buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope)) - return BufferVar(self, buffer_var, None, content_type) + return BufferVarBuilder(self, buffer_var, content_type) - def buffer_ptr(self, buf, shape=None): - """Create pointer variable corresponds to buffer ptr. + def buffer_ptr(self, buf): + """Create a handle to interact with the buffer specified. Parameters ---------- @@ -457,10 +560,10 @@ def buffer_ptr(self, buf, shape=None): Returns ------- - ptr : BufferVar + ptr : BufferBuilder The buffer var representing the buffer. """ - return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype) + return BufferBuilder(self, buf, buf.dtype) def likely(self, expr): """Add likely tag for expression. diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index de200d5eabdd..ea6b3b6a7945 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -309,8 +309,8 @@ class Allocate(Stmt): dtype : str The data type of the buffer. - extents : list of Expr - The extents of the allocate + extent : Expr + The number of elements to allocate condition : PrimExpr The condition. @@ -325,14 +325,14 @@ class Allocate(Stmt): The location of this itervar in the source code. 
""" - def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None): + def __init__(self, buffer_var, dtype, extent, condition, body, annotations=None, span=None): if annotations is None: annotations = dict() self.__init_handle_by_constructor__( _ffi_api.Allocate, # type: ignore buffer_var, dtype, - extents, + extent, condition, body, annotations, diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index e402c5888978..b6384cfb00bb 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -306,9 +306,7 @@ def _nms_loop( ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(tx, "thread_extent", nthread_tx) - num_valid_boxes_local = ib.allocate( - "int32", (1,), name="num_valid_boxes_local", scope="local" - ) + num_valid_boxes_local = ib.allocate("int32", 1, name="num_valid_boxes_local", scope="local") num_valid_boxes_local[0] = 0 def nms_inner_loop(ib, i, j, nkeep): @@ -345,7 +343,7 @@ def nms_inner_loop(ib, i, j, nkeep): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms # No need to do more iteration if we have already reached max_output_size boxes - box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx = ib.allocate("int32", 1, name="box_idx", scope="local") box_idx[0] = 0 with ib.while_loop( tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) diff --git a/python/tvm/topi/cuda/rcnn/proposal.py b/python/tvm/topi/cuda/rcnn/proposal.py index 12f7a23abe35..6bb2e5c2054c 100644 --- a/python/tvm/topi/cuda/rcnn/proposal.py +++ b/python/tvm/topi/cuda/rcnn/proposal.py @@ -176,8 +176,8 @@ def argsort_ir(data_buf, out_index_buf): ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "virtual_thread", nthread_bx) tid = bx * nthread_tx + tx - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + temp_data = ib.allocate("float32", 1, name="temp_data", scope="local") + temp_index = ib.allocate("int32", 1, name="temp_index", scope="local") idxm = tvm.tir.indexmod @@ -299,14 +299,14 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf): tx = te.thread_axis("threadIdx.x") ib = tvm.tir.ir_builder.create() ib.scope_attr(tx, "thread_extent", nthread_tx) - i = ib.allocate("int32", (1,), "i", scope="local") + i = ib.allocate("int32", 1, "i", scope="local") i[0] = 0 p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf) p_remove = ib.buffer_ptr(remove_mask_buf) p_out = ib.buffer_ptr(out_buf) b = tx - nkeep = ib.allocate("int32", (1,), "nkeep", scope="local") + nkeep = ib.allocate("int32", 1, "nkeep", scope="local") nkeep[0] = 0 # number of bbox after nms with ib.for_range(0, num_bbox) as j: diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 0d19a92f2058..6c2151f7b573 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -121,9 +121,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i by = te.thread_axis("blockIdx.y") ib.scope_attr(by, "thread_extent", nthread_by) - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") + start = ib.allocate("int64", 1, name="start", scope="local") + middle = ib.allocate("int64", 1, name="middle", scope="local") + end = ib.allocate("int64", 1, name="end", scope="local") start[0] = width * tid 
            with ib.if_scope(start[0] < scan_axis_size):
                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
@@ -159,10 +159,10 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
            by = te.thread_axis("blockIdx.y")
            ib.scope_attr(by, "thread_extent", nthread_by)
-            start = ib.allocate("int64", (1,), name="start", scope="local")
-            middle = ib.allocate("int64", (1,), name="middle", scope="local")
-            end = ib.allocate("int64", (1,), name="end", scope="local")
-            tmp = ib.allocate(out_dtype, (1,), name="end", scope="local")
+            start = ib.allocate("int64", 1, name="start", scope="local")
+            middle = ib.allocate("int64", 1, name="middle", scope="local")
+            end = ib.allocate("int64", 1, name="end", scope="local")
+            tmp = ib.allocate(out_dtype, 1, name="end", scope="local")
            start[0] = width * tid
            with ib.if_scope(tvm.tir.all(start[0] < scan_axis_size)):
                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index fa7545cd323a..09bc48404680 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -642,7 +642,7 @@ def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):

    ni = indices.shape[0]

-    atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local")
+    atomic_add_return = ib.allocate(updates.dtype, 1, name="atomic_add_return", scope="local")

    with ib.new_scope():
        nthread_bx = ceil_div(ni, nthread_tx)
@@ -772,9 +772,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
        updates = ib.buffer_ptr(updates_ptr)
        out = ib.buffer_ptr(out_ptr)

-        atomic_add_return = ib.allocate(
-            updates.dtype, (1,), name="atomic_add_return", scope="local"
-        )
+        atomic_add_return = ib.allocate(updates.dtype, 1, name="atomic_add_return", scope="local")

        fused_indices_dimension = 1
        for i in indices_ptr.shape[1:]:
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 25cc7a4e2cfb..9cbe89947fb2 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -134,25 +134,25 @@ def _odd_even_sort(
    ## Create shared memory as syncable thread scratch space
    tmp_keys_swap = ib.allocate(
        keys_swap.dtype,
-        (block_size,),
+        block_size,
        name="temp_keys_swap",
        scope="shared",
    )
    if values_swap is not None:
        tmp_values_swap = ib.allocate(
            values_swap.dtype,
-            (block_size,),
+            block_size,
            name="temp_values_swap",
            scope="shared",
        )

    ## Create thread local data for swapping
-    temp_keys = ib.allocate(keys_swap.dtype, (1,), name="temp_keys", scope="local")
+    temp_keys = ib.allocate(keys_swap.dtype, 1, name="temp_keys", scope="local")
    if values_swap is not None:
-        temp_values = ib.allocate(values_swap.dtype, (1,), name="temp_values", scope="local")
+        temp_values = ib.allocate(values_swap.dtype, 1, name="temp_values", scope="local")

-    temp_cond1 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond1", scope="local")
-    temp_cond2 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond2", scope="local")
+    temp_cond1 = ib.allocate(keys_swap.dtype, 1, name="temp_cond1", scope="local")
+    temp_cond2 = ib.allocate(keys_swap.dtype, 1, name="temp_cond2", scope="local")
    # Copy data to scratch space
    base_idx = by * size * axis_mul_after + bz
    with ib.for_range(0, 2) as n:
@@ -255,9 +255,9 @@ def compare(a, b):
    upper_lim = ceil_log2(size)

    def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
-        first = ib.allocate("int64", (1,), name="first", scope="local")
-        mid = ib.allocate("int64", (1,), name="mid", scope="local")
-        last = ib.allocate("int64", (1,), name="last", scope="local")
ib.allocate("int64", (1,), name="last", scope="local") + first = ib.allocate("int64", 1, name="first", scope="local") + mid = ib.allocate("int64", 1, name="mid", scope="local") + last = ib.allocate("int64", 1, name="last", scope="local") first[0] = tvm.te.max(0, diag - bCount) last[0] = tvm.te.min(diag, aCount) with ib.while_loop(first[0] < last[0]): @@ -286,8 +286,8 @@ def serial_merge( first, last, ): - i = ib.allocate("int64", (1,), name="i", scope="local") - j = ib.allocate("int64", (1,), name="j", scope="local") + i = ib.allocate("int64", 1, name="i", scope="local") + j = ib.allocate("int64", 1, name="j", scope="local") i[0] = aStart + first j[0] = bStart + diag - last with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count: diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index 32f20a15016e..866ac58eaf03 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -149,7 +149,6 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) m = data.shape[1] nb = w_indptr.shape[0] - 1 - nnzb = w_data.shape[0] # treat csr like block size 1 bsr if len(w_data.shape) == 1: bs_n = 1 @@ -181,7 +180,7 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - w_data_ptr = ib.buffer_ptr(w_data, shape=(nnzb, bs_n, bs_k)) + w_data_ptr = ib.buffer_ptr(w_data) w_indices_ptr = ib.buffer_ptr(w_indices) w_indptr_ptr = ib.buffer_ptr(w_indptr) @@ -193,18 +192,20 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): rowlength_bo = ceil_div(w_indptr_ptr[n_index + 1] - row_start, rowlength_bi) # thread local storage for bs_m x bs_n block - block = ib.allocate(data.dtype, (bs_m, bs_n), name="block", scope="local") - data_cache = ib.allocate(data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local") + block = ib.buffer_realize(data.dtype, (bs_m, bs_n), name="block", scope="local") + data_cache = ib.buffer_realize( + data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local" + ) if use_warp_storage: - indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp") - w_data_cache = ib.allocate( + indices = ib.allocate(w_indices.dtype, rowlength_bi, name="indices", scope="warp") + w_data_cache = ib.buffer_realize( w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp" ) else: - indices = ib.allocate( + indices = ib.buffer_realize( w_indices.dtype, (ni, rowlength_bi), name="indices", scope="shared" ) - w_data_cache = ib.allocate( + w_data_cache = ib.buffer_realize( w_data.dtype, (ni, rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="shared" ) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 7a796fa42696..53161a2126cb 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -88,22 +88,20 @@ def gen_ir( new_shape_size = new_shape_ptr.shape[0] multipliers = ib.allocate( - new_shape_ptr.dtype, (prev_shape_size,), name="multipliers", scope="global" - ) - dividers = ib.allocate( - new_shape_ptr.dtype, (new_shape_size,), name="dividers", scope="global" + new_shape_ptr.dtype, prev_shape_size, name="multipliers", scope="global" ) + dividers = ib.allocate(new_shape_ptr.dtype, new_shape_size, name="dividers", scope="global") flattened_indices = ib.allocate( new_shape_ptr.dtype, - (sparse_indices_ptr.shape[0],), + sparse_indices_ptr.shape[0], name="flattened_indices", 
scope="global", ) - total_ele = ib.allocate(new_shape_ptr.dtype, (1,), name="total_ele", scope="global") + total_ele = ib.allocate(new_shape_ptr.dtype, 1, name="total_ele", scope="global") division_total_ele = ib.allocate( - new_shape_ptr.dtype, (1,), name="division_total_ele", scope="global" + new_shape_ptr.dtype, 1, name="division_total_ele", scope="global" ) - equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="global") + equal_shape = ib.allocate("bool", 1, name="equal_shape", scope="global") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): # The computation in this block is very very miniscule since we are just iterating over @@ -183,7 +181,7 @@ def gen_ir( with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): current_element = ib.allocate( - new_shape_ptr.dtype, (1,), name="current_element", scope="local" + new_shape_ptr.dtype, 1, name="current_element", scope="local" ) current_element[0] = flattened_indices[row_number] diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index e577104c3ddc..eb7a09a369ed 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -303,8 +303,8 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr): with irb.for_range(0, nnz, kind="serial", name="nz_idx") as nz_idx: out_indptr_ptr[indices_ptr[nz_idx]] += 1 - cumsum = irb.allocate("int32", (1,), name="cumsum", scope="local") - temp = irb.allocate("int32", (1,), name="temp", scope="local") + cumsum = irb.allocate("int32", 1, name="cumsum", scope="local") + temp = irb.allocate("int32", 1, name="temp", scope="local") cumsum[0] = 0 with irb.for_range(0, n, kind="serial", name="col") as col: temp[0] = out_indptr_ptr[col] @@ -325,8 +325,8 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr): out_data_ptr[dest] = data_ptr[real_idx] out_indptr_ptr[col] += 1 - last = irb.allocate("int32", (1,), name="last", scope="local") - temp2 = irb.allocate("int32", (1,), name="temp2", scope="local") + last = irb.allocate("int32", 1, name="last", scope="local") + temp2 = irb.allocate("int32", 1, name="temp2", scope="local") last[0] = 0 with irb.for_range(0, n, kind="serial", name="col") as col: temp2[0] = out_indptr_ptr[col] diff --git a/python/tvm/topi/sparse/csrmm.py b/python/tvm/topi/sparse/csrmm.py index 4d659c801103..31b6e4c06c90 100644 --- a/python/tvm/topi/sparse/csrmm.py +++ b/python/tvm/topi/sparse/csrmm.py @@ -81,7 +81,7 @@ def csrmm_default_ir(data, indices, indptr, weight, out): _, N = weight.shape with irb.for_range(0, N, kind="vectorize", name="n") as n: with irb.for_range(0, M, kind="parallel", name="row") as row: - dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + dot = irb.allocate(data.dtype, 1, name="dot", scope="local") out_ptr[row * N + n] = cast(0, data.dtype) dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] diff --git a/python/tvm/topi/sparse/csrmv.py b/python/tvm/topi/sparse/csrmv.py index 3c2016c6513a..7d54133a5a5f 100644 --- a/python/tvm/topi/sparse/csrmv.py +++ b/python/tvm/topi/sparse/csrmv.py @@ -71,7 +71,7 @@ def csrmv_default_ir(data, indices, indptr, weight, out): out_ptr = irb.buffer_ptr(out) num_rows = indptr.shape[0] - 1 with irb.for_range(0, num_rows, kind="parallel", name="row") as row: - dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + dot = irb.allocate(data.dtype, 1, name="dot", scope="local") out_ptr[row] = cast(0, data.dtype) dot[0] = cast(0, data.dtype) row_start = 
diff --git a/python/tvm/topi/sparse/dense.py b/python/tvm/topi/sparse/dense.py
index 5c63e44f691a..e40ed15d2535 100644
--- a/python/tvm/topi/sparse/dense.py
+++ b/python/tvm/topi/sparse/dense.py
@@ -76,7 +76,7 @@ def dense_default_ir(data, indices, indptr, weight, out):
    N, K = weight.shape
    with irb.for_range(0, N, kind="vectorize", name="n") as n:
        with irb.for_range(0, M, kind="parallel", name="m") as m:
-            dot = irb.allocate(dtype, (1,), name="dot", scope="local")
+            dot = irb.allocate(dtype, 1, name="dot", scope="local")
            out_ptr[m * N + n] = tvm.tir.const(0, dtype)
            dot[0] = tvm.tir.const(0, dtype)
            row_start = indptr_ptr[m]
@@ -155,7 +155,7 @@ def dense_default_ir(data, w_data, w_indices, w_indptr, out):
    N = simplify(w_indptr.shape[0] - 1)
    with irb.for_range(0, M, kind="vectorize", name="m") as m:
        with irb.for_range(0, N, kind="parallel", name="n") as n:
-            dot = irb.allocate(dtype, (1,), name="dot", scope="local")
+            dot = irb.allocate(dtype, 1, name="dot", scope="local")
            out_ptr[m * N + n] = tvm.tir.const(0, dtype)
            dot[0] = tvm.tir.const(0, dtype)
            row_start = w_indptr_ptr[n]
diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py
index b25bd854a7f9..7dc3bb2d96b2 100644
--- a/python/tvm/topi/sparse_reshape.py
+++ b/python/tvm/topi/sparse_reshape.py
@@ -89,19 +89,17 @@ def gen_ir(
    new_shape_size = new_shape_ptr.shape[0]

    multipliers = ib.allocate(
-        new_shape_ptr.dtype, (prev_shape_size,), name="multipliers", scope="local"
-    )
-    dividers = ib.allocate(
-        new_shape_ptr.dtype, (new_shape_size,), name="dividers", scope="local"
+        new_shape_ptr.dtype, prev_shape_size, name="multipliers", scope="local"
    )
+    dividers = ib.allocate(new_shape_ptr.dtype, new_shape_size, name="dividers", scope="local")

    flattened_indices = ib.allocate(
        new_shape_ptr.dtype,
-        (sparse_indices_ptr.shape[0],),
+        sparse_indices_ptr.shape[0],
        name="flattened_indices",
        scope="local",
    )
-    total_ele = ib.allocate(new_shape_ptr.dtype, (1,), name="total_ele", scope="local")
+    total_ele = ib.allocate(new_shape_ptr.dtype, 1, name="total_ele", scope="local")
    total_ele[0] = prev_shape[0]

    # Cumulative Reverse Exclusive Multiply
@@ -114,7 +112,7 @@ def gen_ir(
        total_ele[0] *= prev_shape[prev_shape_size - i]

    division_total_ele = ib.allocate(
-        new_shape_ptr.dtype, (1,), name="division_total_ele", scope="local"
+        new_shape_ptr.dtype, 1, name="division_total_ele", scope="local"
    )
    division_total_ele[0] = Cast(new_shape_ptr.dtype, 1)
    with ib.for_range(0, new_shape_size) as i:
@@ -130,7 +128,7 @@ def gen_ir(
        with ib.else_scope():
            out_new_shape[i] = new_shape[i]

-    equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local")
+    equal_shape = ib.allocate("bool", 1, name="equal_shape", scope="local")

    # Check if prev_shape and new_shape are equal
    equal_shape[0] = True
@@ -163,7 +161,7 @@ def gen_ir(
    with ib.for_range(0, new_sparse_indices_ptr.shape[0], kind="parallel") as i:
        current_element = ib.allocate(
-            new_shape_ptr.dtype, (1,), name="current_element", scope="local"
+            new_shape_ptr.dtype, 1, name="current_element", scope="local"
        )
        current_element[0] = flattened_indices[i]
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index 7a51946a279a..6acaa9a3e299 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -655,9 +655,9 @@ def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local):

    with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
        num_valid_boxes_local = ib.allocate(
-            "int32", (1,), name="num_valid_boxes_local", scope="local"
scope="local" + "int32", 1, name="num_valid_boxes_local", scope="local" ) - box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx = ib.allocate("int32", 1, name="box_idx", scope="local") num_valid_boxes_local[0] = 0 box_idx[0] = 0 diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index d12592fd111a..7d55468571d4 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -60,8 +60,8 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): def binary_search(ib, y, num_boxes, scores, score_threshold, out): """Binary search for score_threshold on scores sorted in descending order""" - lo = ib.allocate("int32", (1,), name="lo", scope="local") - hi = ib.allocate("int32", (1,), name="hi", scope="local") + lo = ib.allocate("int32", 1, name="lo", scope="local") + hi = ib.allocate("int32", 1, name="hi", scope="local") lo[0] = 0 hi[0] = num_boxes diff --git a/python/tvm/topi/vision/rcnn/proposal.py b/python/tvm/topi/vision/rcnn/proposal.py index 12a0d6bcf0a0..23b0d5d39ebf 100644 --- a/python/tvm/topi/vision/rcnn/proposal.py +++ b/python/tvm/topi/vision/rcnn/proposal.py @@ -205,8 +205,8 @@ def argsort_ir(data_buf, out_index_buf): ib = tvm.tir.ir_builder.create() p_data = ib.buffer_ptr(data_buf) index_out = ib.buffer_ptr(out_index_buf) - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + temp_data = ib.allocate("float32", 1, name="temp_data", scope="local") + temp_index = ib.allocate("int32", 1, name="temp_index", scope="local") idxm = tvm.tir.indexmod with ib.for_range(0, batch, kind="unroll") as b: start = b * num_bbox @@ -316,12 +316,12 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf): batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape) rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch ib = tvm.tir.ir_builder.create() - i = ib.allocate("int32", (batch,), "i", scope="local") + i = ib.allocate("int32", batch, "i", scope="local") p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf) p_remove = ib.buffer_ptr(remove_mask_buf) p_out = ib.buffer_ptr(out_buf) - nkeep = ib.allocate("int32", (batch,), "nkeep", scope="local") + nkeep = ib.allocate("int32", batch, "nkeep", scope="local") with ib.for_range(0, batch) as b: nkeep[b] = 0 diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fa132f079793..4ef4da8c7614 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -449,9 +449,8 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto scope = GetPtrStorageScope(op->buffer_var); - doc << "allocate(" << Print(op->buffer_var) << ", "; - doc << PrintDType(op->dtype) << ", "; - doc << Print(op->extents) << "), storage_scope = " << scope; + doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " + << Print(op->extent) << "), storage_scope = " << scope; if (!op->annotations.empty()) { std::vector attr_docs; for (const auto& it : op->annotations) { diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index fa74e56f491c..124fa0400557 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -764,7 +764,7 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto storage_scope = GetPtrStorageScope(op->buffer_var); if (current_num_ != num_child_ - 1) { - 
doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " + doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extent) << ", " << PrintDType(op->dtype) << ", " << Print(storage_scope); if (!is_one(op->condition)) { doc << ", " << Print(op->condition); @@ -777,7 +777,7 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { doc << ") as " << Print(op->buffer_var) << ":"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } else { - doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents) + doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extent) << ", " << PrintDType(op->dtype) << ", " << Print(storage_scope); if (!is_one(op->condition)) { doc << ", " << Print(op->condition); diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 2ed5fd4029a2..b3199405916f 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -224,10 +224,9 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, assign_body = MergeNest(MakeIfNest(output_preds), assign_body); Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { - body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, 1, const_true(), body); if (!normal_red.empty()) { - body = - Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, 1, const_true(), body); } } body = Substitute(body, value_map); diff --git a/src/tir/analysis/calculate_workspace.cc b/src/tir/analysis/calculate_workspace.cc index 49ddaf613c6d..739fb9b99d52 100644 --- a/src/tir/analysis/calculate_workspace.cc +++ b/src/tir/analysis/calculate_workspace.cc @@ -55,15 +55,7 @@ size_t WorkspaceCalculator::GetByteAlignedSize(size_t non_aligned_size) { size_t WorkspaceCalculator::CalculateExtentsSize(const AllocateNode* op) { size_t element_size_bytes = op->dtype.bytes(); - size_t num_elements = 1; - for (const auto& ext : op->extents) { - if (ext->IsInstance()) { - num_elements *= Downcast(ext)->value; - } else { - // We cant statically calculate workspace for dynamic shapes - num_elements = 0; - } - } + size_t num_elements = op->constant_allocation_size(); return GetByteAlignedSize(num_elements * element_size_bytes); } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 24aacc3c04f7..41fe556e5688 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -291,6 +291,21 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp } } +int32_t BufferNode::NumElements() const { + int64_t result = 1; + for (const PrimExpr& dim : shape) { + if (const IntImmNode* int_size = dim.as()) { + result *= int_size->value; + if (result > std::numeric_limits::max()) { + return 0; + } + } else { + return 0; + } + } + return static_cast(result); +} + PrimExpr Buffer::vload(Array begin, DataType dtype) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 0d42c20c2822..f2f46e32af3b 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -332,18 +332,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Allocate -Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt 
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index d60ec72a7589..3b3a9f32ccbe 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -53,7 +53,7 @@ void StmtVisitor::VisitStmt_(const WhileNode* op) {
 }

 void StmtVisitor::VisitStmt_(const AllocateNode* op) {
-  VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
+  this->VisitExpr(op->extent);
   this->VisitStmt(op->body);
   this->VisitExpr(op->condition);
 }
@@ -304,15 +304,15 @@ Stmt StmtMutator::VisitStmt_(const WhileNode* op) {
 }

 Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
-  Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
+  PrimExpr extent = this->VisitExpr(op->extent);
   Stmt body = this->VisitStmt(op->body);
   PrimExpr condition = this->VisitExpr(op->condition);
this->VisitExpr(op->condition); - if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { + if (extent.same_as(op->extent) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { auto n = CopyOnWrite(op); - n->extents = std::move(extents); + n->extent = std::move(extent); n->body = std::move(body); n->condition = std::move(condition); return Stmt(n); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 76845cbebd2a..b8f83db977ed 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -220,7 +220,7 @@ class BF16LowerRewriter : public StmtExprMutator { DataType dtype = DataType::UInt(16, op->dtype.lanes()); Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); var_remap_[op->buffer_var] = buffer_var; - return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); + return VisitStmt(Allocate(buffer_var, dtype, op->extent, op->condition, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 3b6af0644fc9..7b9be15a83ab 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -61,7 +61,7 @@ class BoundChecker : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { // If the shape was updated we should update the hashtable. if (UpdateIsNeeded(op->buffer_var)) { - Update(op->buffer_var, op->extents, op->dtype); + Update(op->buffer_var, op->extent, op->dtype); } return StmtExprMutator::VisitStmt_(op); } @@ -108,28 +108,14 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, const Array& new_shape, const DataType& type) { + void Update(const Var& buffer_var, const PrimExpr new_extent, const DataType& type) { // Sanity check at first. - if (!new_shape.size()) { + if (!new_extent.defined() || !new_extent.dtype().is_scalar() || is_negative_const(new_extent)) { return; } - for (size_t i = 0; i < new_shape.size(); ++i) { - if (!new_shape[0].defined() || !new_shape[i].dtype().is_scalar() || - is_negative_const(new_shape[i])) { - return; - } - } - - // Scalarize the shape. - PrimExpr shape = - Mul(make_const(DataType::UInt(64), type.lanes()), Cast(DataType::UInt(64), new_shape[0])); - for (size_t i = 1; i < new_shape.size(); ++i) { - // Cast to unsigned to avoid integer overlow at frist. - shape = Mul(shape, Mul(make_const(DataType::UInt(64), type.lanes()), - Cast(DataType::UInt(64), new_shape[i]))); - } - mem_to_shape_[buffer_var.get()] = shape; + // Define the extent including lanes. 
+    mem_to_shape_[buffer_var.get()] = Mul(make_const(DataType::UInt(64), type.lanes()), new_extent);
   }

   bool IndexIsValid(const PrimExpr& index) const {
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index e0ab95a537e7..a2d31fada422 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -134,7 +134,7 @@ class BufferFlattener : public StmtExprMutator {
   static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) {
     String storage_scope = buffer.scope();
     PrimExpr area = BufferArea(buffer);
-    body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body));
+    body = Allocate(buffer->data, buffer->dtype, area, const_true(), std::move(body));
     return body;
   }

diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc
index 0b45bde28dfe..0137b9133e87 100644
--- a/src/tir/transforms/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -107,19 +107,14 @@ class DoubleBufferInjector : public StmtExprMutator {
     auto it = dbuffer_info_.find(buf);
     if (it != dbuffer_info_.end()) {
       it->second.scope = GetPtrStorageScope(op->buffer_var);
-      it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                                make_const(DataType::Int(32), 1), op->extents) *
-                          op->dtype.lanes();
+      it->second.stride = op->extent * op->dtype.lanes();
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<AllocateNode>();
-      Array<PrimExpr> new_extents{make_const(op->extents[0].dtype(), 2)};
-      for (PrimExpr e : op->extents) {
-        new_extents.push_back(e);
-      }
+      PrimExpr new_extent = mul(make_const(op->extent.dtype(), 2), op->extent);
       ICHECK(it->second.loop != nullptr);
       auto& alloc_nest = loop_allocs_[it->second.loop];
       alloc_nest.emplace_back(
-          Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
+          Allocate(op->buffer_var, op->dtype, new_extent, op->condition, Evaluate(0)));
       return op->body;
     } else {
       return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc
index 4964bec0334e..1fa2265d75db 100644
--- a/src/tir/transforms/inject_virtual_thread.cc
+++ b/src/tir/transforms/inject_virtual_thread.cc
@@ -124,9 +124,7 @@ class VarTouchedAnalysis : public StmtVisitor {
   }
   void VisitStmt_(const AllocateNode* op) final {
     ExprTouched tc(touched_var_, false);
-    for (size_t i = 0; i < op->extents.size(); ++i) {
-      tc(op->extents[i]);
-    }
+    tc(op->extent);
     tc.VisitExpr(op->condition);
     Record(op->buffer_var.get(), tc);
     this->VisitStmt(op->body);
@@ -359,44 +357,30 @@ class VTInjector : public StmtExprMutator {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     }

-    bool changed = false;
-    Array<PrimExpr> extents;
-    for (size_t i = 0; i < op->extents.size(); i++) {
-      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
-      if (visit_touched_var_ && !vt_loop_injected_) {
-        return InjectVTLoop(GetRef<Stmt>(op), true);
-      }
-      if (!new_ext.same_as(op->extents[i])) changed = true;
-      extents.push_back(new_ext);
+    PrimExpr extent = this->VisitExpr(op->extent);
+    if (visit_touched_var_ && !vt_loop_injected_) {
+      return InjectVTLoop(GetRef<Stmt>(op), true);
     }
+
     visit_touched_var_ = false;
     Stmt body;
     // always rewrite if not allow sharing.
     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
       // place v on highest dimension.
-      PrimExpr stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                              make_const(DataType::Int(32), 1), op->extents) *
-                        op->dtype.lanes();
-      Array<PrimExpr> other;
-      other.push_back(make_const(op->extents[0].dtype(), num_threads_));
-      for (PrimExpr e : extents) {
-        other.push_back(e);
-      }
-      extents = other;
-      changed = true;
+      PrimExpr stride = mul(op->extent, op->dtype.lanes());
+      extent = mul(extent, num_threads_);
       // mark this buffer get touched.
       alloc_remap_[op->buffer_var.get()] = stride;
-      // Mutate the body.
-      body = this->VisitStmt(op->body);
-    } else {
-      // Mutate the body.
-      body = this->VisitStmt(op->body);
     }
-    if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) {
+
+    // Mutate the body.
+    body = this->VisitStmt(op->body);
+
+    if (extent.same_as(op->extent) && body.same_as(op->body) && condition.same_as(op->condition)) {
       return GetRef<Stmt>(op);
     } else {
-      return Allocate(op->buffer_var, op->dtype, extents, condition, body);
+      return Allocate(op->buffer_var, op->dtype, extent, condition, body);
     }
   }
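With a single flat extent, passes that used to prepend a new dimension (double buffering, virtual-thread expansion) now scale the extent instead, as the two hunks above show. Schematically, in illustrative Python (num_threads assumed to be 4)::

    import tvm

    n = tvm.tir.Var("n", "int32")

    db_extent = 2 * n  # inject_double_buffer: was extents = [2, e0, e1, ...]
    vt_extent = 4 * n  # inject_virtual_thread: was extents = [num_threads, e0, ...]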
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 262906ade2e8..ebd39f6a1018 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -167,7 +167,7 @@ class IRConvertSSA final : public StmtExprMutator {
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       scope_[v.get()].pop_back();
       op = stmt.as<AllocateNode>();
-      return Allocate(new_var, op->dtype, op->extents, op->condition, op->body);
+      return Allocate(new_var, op->dtype, op->extent, op->condition, op->body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc
index 40d152b3b3b6..6a7fa2319b48 100644
--- a/src/tir/transforms/lift_attr_scope.cc
+++ b/src/tir/transforms/lift_attr_scope.cc
@@ -55,7 +55,7 @@ class AttrScopeLifter : public StmtMutator {
       // undefine them
       attr_node_ = ObjectRef();
       attr_value_ = PrimExpr();
-      return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body);
+      return Allocate(op->buffer_var, op->dtype, op->extent, op->condition, body);
     } else {
       return stmt;
     }
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index 21f1b18d523b..c971c1a863a1 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -96,7 +96,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
       Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
       allocate = stmt.as<AllocateNode>();

-      return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
+      return Allocate(new_buffer_var, new_allocate_type, allocate->extent, allocate->condition,
                      allocate->body);
     } else {
       return StmtExprMutator::VisitStmt_(allocate);
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 6f7c09cdcf2d..155aabb342e1 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -53,7 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop
      // use volatile access to shared buffer.
      body = AttrStmt(remapped, attr::volatile_scope, 1, body);
    }
-    return Allocate(remapped, op->dtype, op->extents, op->condition, body);
+    return Allocate(remapped, op->dtype, op->extent, op->condition, body);
  }
  return StmtExprMutator::VisitStmt_(op);
 }
@@ -98,10 +98,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     if (it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
       if (warp_allocs_.count(repl)) {
-        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
+        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extent, repl->condition, op->body);
         new_storage_scopes_[repl->buffer_var.get()] = "local";
       } else {
-        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
+        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extent, repl->condition, op->body);
         new_storage_scopes_[repl->buffer_var.get()] = "shared";
       }
       return stmt;
@@ -256,7 +256,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // Uses a local variable to store the shuffled data.
       // Later on, this allocation will be properly attached to this statement.
       Var var("t" + std::to_string(idx), ptr_type);
-      Stmt s = Allocate(var, types[idx], {PrimExpr(1)}, pred, Evaluate(0));
+      Stmt s = Allocate(var, types[idx], PrimExpr(1), pred, Evaluate(0));
       local_vars.push_back(s);
     }

@@ -340,8 +340,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       Var var = shared_bufs[i];
       load_remap_[buffers[i]] = Load(types[i], var, index, pred);
       store_remap_[buffers[i]] = var;
-      Array<PrimExpr> extents{PrimExpr(1)};
-      auto node = Allocate(var, types[i], extents, pred, Evaluate(0));
+      PrimExpr extent(1);
+      auto node = Allocate(var, types[i], extent, pred, Evaluate(0));
       alloc_remap_[buffers[i]] = node;
       warp_allocs_.insert(node.get());
     }
@@ -381,7 +381,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
                  BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
       alloc_remap_[buffers[idx]] =
           Allocate(shared_bufs[idx], types[idx],
-                   {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
+                   mul(PrimExpr(group_extent), PrimExpr(reduce_extent)), pred, Evaluate(0));
       store_remap_[buffers[idx]] = shared_bufs[idx];
     }
   }
@@ -391,7 +391,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     for (auto var : local_vars) {
       const AllocateNode* repl = var.as<AllocateNode>();
       if (repl) {
-        body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
+        body = Allocate(repl->buffer_var, repl->dtype, repl->extent, repl->condition, body);
         new_storage_scopes_[repl->buffer_var.get()] = "local";
       }
     }
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 062d67eef165..910d532b1921 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -128,10 +128,7 @@ class BuiltinLower : public StmtExprMutator {
        }
      }
    }
-    PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
-    for (size_t i = 0; i < op->extents.size(); ++i) {
-      total_bytes = total_bytes * op->extents[i];
-    }
+    PrimExpr total_bytes = make_const(op->extent.dtype(), nbytes) * op->extent;
    ICHECK(device_type_.defined()) << "Unknown device type in current IR";
    ICHECK(device_id_.defined()) << "Unknown device id in current IR";
    Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
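The byte-size computation in lower_tvm_builtin.cc likewise collapses to a single multiply. A sketch of the equivalent arithmetic from Python, assuming a vectorized dtype so that the lanes fold into nbytes::

    import tvm

    extent = tvm.tir.Var("n", "int32")
    dtype = tvm.runtime.DataType("float32x4")
    nbytes = (dtype.bits // 8) * dtype.lanes  # 4 bytes x 4 lanes = 16
    total_bytes = tvm.tir.IntImm(extent.dtype, nbytes) * extent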
100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -227,7 +227,7 @@ class WarpAccessRewriter : protected StmtExprMutator { warp_group_ = (alloc_size + (factor - 1)) / factor; alloc_size = warp_group_ * factor; - return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)}, + return Allocate(op->buffer_var, op->dtype, make_const(DataType::Int(32), alloc_size / width_), op->condition, this->VisitStmt(op->body)); } diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index e8865b260dc1..22cb041425ff 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -68,14 +68,13 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { align = std::max(align, alloc->dtype.bytes()); } for (const auto& alloc : dyn_shmem_allocs_) { - ICHECK_EQ(alloc->extents.size(), 1); buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; - merged_alloc_size_ += alloc->extents[0] * align; + merged_alloc_size_ += alloc->extent * align; } allocated = true; - auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, - const_true(), StmtExprMutator::VisitStmt(op->body)); + auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), merged_alloc_size_, const_true(), + StmtExprMutator::VisitStmt(op->body)); return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); } return StmtMutator::VisitStmt_(op); diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index aae1749b27db..6ee05b336344 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -81,7 +81,7 @@ class NoOpRemover : public StmtMutator { Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); - return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; + return is_no_op(op->body) ? 
MakeEvaluate(op->extent) : stmt; } Stmt VisitStmt_(const ProducerRealizeNode* op) final { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 795ae9d6a73a..2509499b156c 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -95,12 +95,8 @@ class VarUseDefAnalysis : public StmtExprMutator { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; - ICHECK_GT(op->extents.size(), 0); - dyn_shmem_size_ = op->extents[0]; - for (size_t i = 1; i < op->extents.size(); ++i) { - dyn_shmem_size_ *= op->extents[i]; - } - dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); + ICHECK(op->extent.defined()); + dyn_shmem_size_ = op->extent * op->dtype.bytes(); use_dyn_shmem_ = true; } return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 6a3ce596c2fe..67436ed833d6 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1134,7 +1134,7 @@ class StorageFlattener : public StmtExprMutator { // use small alignment for small arrays auto dtype = op->buffer->dtype; - int32_t const_size = AllocateNode::constant_allocation_size(shape); + int32_t const_size = op->buffer->NumElements(); int align = GetTempAllocaAlignment(dtype, const_size); if (skey.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(skey.to_string()); @@ -1163,14 +1163,12 @@ class StorageFlattener : public StmtExprMutator { if (strides.size() != 0) { int first_dim = 0; ret = Allocate(e.buffer->data, storage_type, - {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, + e.buffer->strides[first_dim] * e.buffer->shape[first_dim], make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } else { - shape = e.buffer->shape; - if (shape.size() == 0) { - shape.push_back(make_const(DataType::Int(32), 1)); - } - ret = Allocate(e.buffer->data, storage_type, shape, + PrimExpr extent = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), e.buffer->shape); + ret = Allocate(e.buffer->data, storage_type, extent, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 409b7c262954..9ce0d94cb6b2 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -551,10 +551,8 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->allocs.size() == 1) { // simply use the original allocation. 
- PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), e->allocs[0]->extents); - e->new_alloc = - Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); + e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extent, + e->allocs[0]->condition, Evaluate(0)); if (IsSpecialTaggedMemory(e->scope)) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -565,8 +563,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Build a merged allocation PrimExpr combo_size; for (const AllocateNode* op : e->allocs) { - PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), op->extents); + PrimExpr sz = op->extent; auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { @@ -594,8 +591,7 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = combo_size + make_const(DataType::Int(32), 1); } combo_size = analyzer_.Simplify(combo_size); - e->new_alloc = - Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); + e->new_alloc = Allocate(e->alloc_var, alloc_type, combo_size, const_true(), Evaluate(0)); if (IsSpecialTaggedMemory(e->scope)) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -636,8 +632,8 @@ class StoragePlanRewriter : public StmtExprMutator { } uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); PrimExpr alloc_size = - make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits); - e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0)); + make_const(e->allocs[0]->extent.dtype(), (total_bits + type_bits - 1) / type_bits); + e->new_alloc = Allocate(e->alloc_var, e->elem_type, alloc_size, const_true(), Evaluate(0)); if (info.defined()) { ICHECK_LE(total_bits, info->max_num_bits) << "Allocation exceed bound of memory tag " << e->scope.to_string(); @@ -1025,9 +1021,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitStmt_(const AllocateNode* op) final { - const Array& extents = op->extents; - PrimExpr extent = extents[extents.size() - 1]; - OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); + OnArrayDeclaration(op->buffer_var, op->dtype, op->extent, BufferVarInfo::kAllocateNode); StmtExprVisitor::VisitStmt_(op); } @@ -1342,10 +1336,8 @@ class VectorTypeRewriter : public StmtExprMutator { int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); - Array extents = op->extents; - extents.Set(extents.size() - 1, - extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); + PrimExpr extent = op->extent / make_const(op->extent.dtype(), factor); + return Allocate(new_buffer_var, info.new_element_dtype, extent, op->condition, op->body); } /* Update the parameters and all remaining variable references diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index c6e0b5c5f41e..9574628d3773 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -133,6 +133,11 @@ class LoopUnroller : public StmtExprMutator { } } + Stmt VisitStmt_(const BufferStoreNode* op) final { + ++step_count_; + return 
StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const StoreNode* op) final { ++step_count_; return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 4143577a0b17..b37b67019593 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -66,7 +66,7 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); - return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), + return Allocate(remapped, op->dtype, op->extent, StmtExprMutator::VisitExpr(op->condition), StmtExprMutator::VisitStmt(op->body)); } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index cd2d230f5775..8fca62308be2 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -434,21 +434,17 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); } - Array extents; - for (size_t i = 0; i < op->extents.size(); i++) { - PrimExpr new_ext = this->VisitExpr(op->extents[i]); - if (new_ext.dtype().is_vector()) { - LOG(WARNING) << "Cannot handle vector extent in alloc "; - return Scalarize(GetRef(op)); - } - extents.push_back(new_ext); + PrimExpr extent = this->VisitExpr(op->extent); + if (extent.dtype().is_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc "; + return Scalarize(GetRef(op)); } // place the vector lanes in least significant dimension. - extents.push_back(var_lanes_); + extent *= var_lanes_; // rewrite access to buffer internally. 
Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); body = this->VisitStmt(body); - return Allocate(op->buffer_var, op->dtype, extents, condition, body); + return Allocate(op->buffer_var, op->dtype, extent, condition, body); } // scalarize the statement diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 97809b0e1398..59d72b359f2d 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -169,7 +169,7 @@ TEST(IRF, StmtVisitor) { Stmt body = Evaluate(z); DataType dtype = DataType::Float(32); Var buffer("b", PointerType(PrimType(dtype))); - return Allocate(buffer, dtype, {z, z}, const_true(), body); + return Allocate(buffer, dtype, z * z, const_true(), body); }; v(fmaketest()); ICHECK_EQ(v.count, 3); @@ -215,7 +215,15 @@ TEST(IRF, StmtMutator) { Stmt body = Evaluate(z); DataType dtype = DataType::Float(32); Var buffer("b", PointerType(PrimType(dtype))); - return Allocate(buffer, dtype, {1, z}, const_true(), body); + return Allocate(buffer, dtype, z, const_true(), body); + }; + + auto fmakealloc_seq_body = [&]() { + auto z = x + 1; + Stmt body = Evaluate(z); + DataType dtype = DataType::Float(32); + Var buffer("b", PointerType(PrimType(dtype))); + return Allocate(buffer, dtype, z, const_true(), SeqStmt({body, body, body})); }; auto fmakeif = [&]() { @@ -225,23 +233,38 @@ }; MyVisitor v; + { - auto body = fmakealloc(); - Stmt body2 = Evaluate(1); - Stmt bref = body.as<AllocateNode>()->body; - auto* extentptr = body.as<AllocateNode>()->extents.get(); - Array<Stmt> arr{std::move(body), body2, body2}; - auto* arrptr = arr.get(); - arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); - ICHECK(arr.get() == arrptr); - // inplace update body - ICHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x)); - ICHECK(arr[0].as<AllocateNode>()->extents.get() == extentptr); - // copy because there is additional refs - ICHECK(!arr[0].as<AllocateNode>()->body.same_as(bref)); - ICHECK(arr[0].as<AllocateNode>()->body.as<EvaluateNode>()->value.same_as(x)); - ICHECK(bref.as<EvaluateNode>()->value.as<AddNode>()); + // Inplace update of a CopyOnWrite body if there are no additional references. + auto before = fmakealloc_seq_body(); + const AllocateNode* alloc_ptr = before.as<AllocateNode>(); + const SeqStmtNode* before_body_ptr = before.as<AllocateNode>()->body.as<SeqStmtNode>(); + auto after = v(std::move(before)); + + // We get the same AllocateNode, and the same SeqStmt inside it. + ICHECK_EQ(after.get(), alloc_ptr); + auto after_body_ptr = after.as<AllocateNode>()->body.as<SeqStmtNode>(); + ICHECK_EQ(after_body_ptr, before_body_ptr); + // Verify that the change did actually happen. + ICHECK(after_body_ptr->seq[0].as<EvaluateNode>()->value.same_as(x)); + } + + { + // Copy a CopyOnWrite body if there are additional references. + auto before = fmakealloc_seq_body(); + auto extra_ref = before.as<AllocateNode>()->body; + const AllocateNode* alloc_ptr = before.as<AllocateNode>(); + const SeqStmtNode* before_body_ptr = before.as<AllocateNode>()->body.as<SeqStmtNode>(); + auto after = v(std::move(before)); + + // We get the same AllocateNode, but a different SeqStmt inside it. + ICHECK_EQ(after.get(), alloc_ptr); + auto after_body_ptr = after.as<AllocateNode>()->body.as<SeqStmtNode>(); + ICHECK_NE(after_body_ptr, before_body_ptr); + // Verify that the change did actually happen. + ICHECK(after_body_ptr->seq[0].as<EvaluateNode>()->value.same_as(x)); } + { Array<Stmt> arr{fmakealloc()}; // mutating an array that is referenced by another one triggers a copy.
@@ -249,8 +272,8 @@ TEST(IRF, StmtMutator) { auto* arrptr = arr.get(); arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); ICHECK(arr.get() != arrptr); - ICHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x)); - ICHECK(!arr2[0].as<AllocateNode>()->extents[1].same_as(x)); + ICHECK(arr[0].as<AllocateNode>()->extent.same_as(x)); + ICHECK(!arr2[0].as<AllocateNode>()->extent.same_as(x)); // mutate but no content change. arr2 = arr; arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); @@ -265,7 +288,6 @@ TEST(IRF, StmtMutator) { arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); ICHECK(arr2.get() == arr.get()); } - { auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1})); @@ -274,9 +296,9 @@ } { Stmt body = fmakealloc(); + auto* ref1 = body.get(); Stmt body2 = Evaluate(1); auto* ref2 = body2.get(); - auto* extentptr = body.as<AllocateNode>()->extents.get(); // construct a recursive SeqStmt. body = SeqStmt({body}); body = SeqStmt({body, body2}); @@ -284,7 +306,7 @@ body = v(std::move(body)); // the seq gets flattened ICHECK(body.as<SeqStmtNode>()->size() == 3); - ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() == extentptr); + ICHECK(body.as<SeqStmtNode>()->seq[0].get() == ref1); ICHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2); } @@ -292,14 +314,14 @@ { // Cannot CoW because of bref Stmt body = fmakealloc(); Stmt body2 = Evaluate(1); - auto* extentptr = body.as<AllocateNode>()->extents.get(); + auto* extentptr = body.as<AllocateNode>()->extent.get(); // construct a recursive SeqStmt. body = SeqStmt({body}); auto bref = body; body = SeqStmt({body, body2}); body = v(std::move(body)); // the seq gets flattened - ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() != extentptr); + ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extent.get() != extentptr); } { @@ -317,8 +339,8 @@ TEST(IRF, StmtMutator) { body = v(std::move(block_realize)); // the body should be changed Block new_block = body.as<BlockRealizeNode>()->block; - ICHECK(new_block->body.as<AllocateNode>()->extents[1].same_as(x)); - ICHECK(new_block->init.as<AllocateNode>()->extents[1].same_as(x)); + ICHECK(new_block->body.as<AllocateNode>()->extent.same_as(x)); + ICHECK(new_block->init.as<AllocateNode>()->extent.same_as(x)); ICHECK(new_block->reads[0]->region[0]->min.same_as(x)); ICHECK(new_block->writes[0]->region[0]->min.same_as(x)); ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 60ed352edcfd..3a2f1411bf0e 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -45,8 +45,8 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_7 = T.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([128], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") + placeholder_global = T.allocate(128, "uint8", "global") + placeholder_d_global = T.allocate(32, "uint8", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 
0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) @@ -119,7 +119,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_2 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([4096], "int8", "global") + ethosu_write_2 = T.allocate(4096, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 592, 12, T.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 160, 12, T.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None @@ -187,9 +187,9 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_8 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_9 = T.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([4096], "int8", "global") - placeholder_global = T.allocate([80], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") + ethosu_write_2 = T.allocate(4096, "int8", "global") + placeholder_global = T.allocate(80, "uint8", "global") + placeholder_d_global = T.allocate(32, "uint8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_6.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index f76a59dd1eb3..7c172c18d1d6 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -202,7 +202,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_3 = 
T.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([1024], "int8", "global") + ethosu_write_2 = T.allocate(1024, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) @@ -223,7 +223,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([1536], "int8", "global") + ethosu_write_2 = T.allocate(1536, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) @@ -244,7 +244,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_3 = T.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([2560], 
"int8", "global") + ethosu_write_2 = T.allocate(2560, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) @@ -267,7 +267,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_2 = T.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([2304], "int8", "global") + ethosu_write_2 = T.allocate(2304, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 76b7ef2a70ee..2c869ea51849 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -39,8 +39,8 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_1 = T.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, 
offset_factor=1) ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([304], "uint8", "global") - placeholder_d_global = T.allocate([80], "uint8", "global") + placeholder_global = T.allocate(304, "uint8", "global") + placeholder_d_global = T.allocate(80, "uint8", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 8240b392a1cf..79d854cf3bc5 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -60,8 +60,8 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_8 = T.match_buffer(placeholder_2, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_4, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") - ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") + ethosu_conv2d_2 = T.allocate(1024, "uint8", "global") + ethosu_conv2d_3 = T.allocate(2048, "uint8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7.data, 0), 0, 12, T.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="uint8")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9.data, 0), 0, 12, T.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7.data, 0), 0, 12, T.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) @@ -82,8 +82,8 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_5 = T.match_buffer(placeholder_2, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) placeholder_4 = T.match_buffer(placeholder_1, [8, 1, 1, 32], dtype="uint8", 
elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([256], "uint8", "global") - placeholder_d_global = T.allocate([8], "int32", "global") + placeholder_global = T.allocate(256, "uint8", "global") + placeholder_d_global = T.allocate(8, "int32", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", placeholder_4.data, 0), 256, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("int32", placeholder_5.data, 0), 8, T.load("int32", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 0, 12, T.load("uint8", placeholder_d_global, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) @@ -109,8 +109,8 @@ def main(placeholder: T.handle, ethosu_conv2d: T.handle, placeholder_1: T.handle placeholder_9 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer = T.match_buffer(placeholder_8, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([144], "uint8", "global") - placeholder_d_global = T.allocate([20], "uint8", "global") + placeholder_global = T.allocate(144, "uint8", "global") + placeholder_d_global = T.allocate(20, "uint8", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5.data, 0), 144, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 144, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) @@ -148,9 +148,9 @@ def main(placeholder: T.handle, placeholder_1: T.handle, ethosu_conv2d: T.handle buffer_4 = T.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_8 = T.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_conv2d_2 = T.allocate([4096], "uint8", "global") - placeholder_global = T.allocate([80], "uint8", "global") - placeholder_d_global = T.allocate([20], "uint8", "global") + ethosu_conv2d_2 = T.allocate(4096, "uint8", "global") + placeholder_global = T.allocate(80, "uint8", "global") + placeholder_d_global = T.allocate(20, "uint8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 16, 16, 0, 16, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_5.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) 
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 8c8d601672ac..dd3508d2ed6b 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -761,7 +761,7 @@ def test_llvm_lower_atomic(): def do_atomic_add(A): ib = tvm.tir.ir_builder.create() n = A.shape[0] - atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local") + atomic_add_return = ib.allocate(A.dtype, 1, name="atomic_add_return", scope="local") one = tvm.tir.const(1, A.dtype) A_ptr = ib.buffer_ptr(A) with ib.for_range(0, n, name="i", kind="parallel") as i: @@ -787,7 +787,7 @@ def test_llvm_gpu_lower_atomic(): def do_atomic_add(A): ib = tvm.tir.ir_builder.create() n = A.shape[0] - atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local") + atomic_add_return = ib.allocate(A.dtype, 1, name="atomic_add_return", scope="local") one = tvm.tir.const(1, A.dtype) A_ptr = ib.buffer_ptr(A) nthread_tx = 64 diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 1edc5d311759..bf29b3cd20f2 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -333,7 +333,7 @@ def do_compute(A, B, n): if "gpu" in target.keys: ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) - iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + iterations = ib.allocate("int32", 1, name="iterations", scope="local") iterations[0] = 0 B[0] = 0 @@ -503,10 +503,10 @@ def do_compute(ins, outs): store_index = index_map[store_type] if indirect_indices: - load_index = tvm.tir.expr.Load("int32x4", R, load_index) + load_index = tvm.tir.expr.Load("int32x4", R.asobject().data, load_index) - transfer = tvm.tir.expr.Load("int32x4", A, load_index) - ib.emit(tvm.tir.stmt.Store(B, transfer, store_index)) + transfer = tvm.tir.expr.Load("int32x4", A.asobject().data, load_index) + ib.emit(tvm.tir.stmt.Store(B.asobject().data, transfer, store_index)) return ib.get() @@ -536,7 +536,7 @@ def do_compute(ins, outs): ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) - array = ib.allocate("int32", (alloc_nbytes,), name="array", scope="shared") + array = ib.allocate("int32", alloc_nbytes, name="array", scope="shared") array[0] = 0 out[0] = array[0] diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index bc4bc4f56e19..db7fd7b1624b 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -607,7 +607,7 @@ def collect_visit(stmt, f): def visit_stmt(op): if isinstance(op, tvm.tir.Allocate): - return op.extents[0].value == 97 + return op.extent.value == 97 return False assert not any(collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse))) diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index 4b61625014e2..9a7250f2178c 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -31,8 +31,8 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand placeholder_149 = 
T.match_buffer(placeholder_146, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_49 = T.match_buffer(T_cast_48, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_22 = T.allocate([131072], "int16", "global") - DepthwiseConv2d_9 = T.allocate([100352], "int32", "global") + PaddedInput_22 = T.allocate(131072, "int16", "global") + DepthwiseConv2d_9 = T.allocate(100352, "int32", "global") for i1_29, i2_39, i3_40 in T.grid(16, 16, 512): PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), T.load("int16", placeholder_147.data, ((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)), T.int16(0), dtype="int16") for i_9, j_9, c_9 in T.grid(14, 14, 512): @@ -62,25 +62,25 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl placeholder_167 = T.match_buffer(placeholder_164, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_77 = T.match_buffer(T_cast_76, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_25 = T.allocate([1, 16, 16, 512], "int16", "global") + PaddedInput_25 = T.allocate(1*16*16*512, "int16", "global") for i1_35, i2_46, i3_47 in T.grid(16, 16, 512): PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), T.load("int16", placeholder_165.data, ((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)), T.int16(0), dtype="int16") - T_add_11 = T.allocate([1, 14, 14, 512], "int32", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as DepthwiseConv2d_11: + T_add_11 = T.allocate(1*14*14*512, "int32", "global") + with T.allocate(1*14*14*512, "int32", "global") as DepthwiseConv2d_11: for i_11, j_11, c_11 in T.grid(14, 14, 512): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0 for di_11, dj_11 in T.grid(3, 3): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (T.load("int32", DepthwiseConv2d_11, (((i_11*7168) + (j_11*512)) + c_11)) + (T.load("int16", PaddedInput_25, (((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)).astype("int32")*T.load("int16", placeholder_166.data, (((di_11*1536) + (dj_11*512)) + c_11)).astype("int32"))) for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512): T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (T.load("int32", DepthwiseConv2d_11, (((ax1_44*7168) + (ax2_45*512)) + ax3_47)) + T.load("int32", placeholder_167.data, ax3_47)) - compute_22 = T.allocate([1, 14, 14, 512], "int32", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as T_cast_78: + compute_22 = T.allocate(1*14*14*512, "int32", "global") + with T.allocate(1*14*14*512, "int32", "global") as T_cast_78: for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512): T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T.load("int32", T_add_11, (((ax1_45*7168) + (ax2_46*512)) + ax3_48)) for i1_36, i2_47, i3_48 in T.grid(14, 14, 512): compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T.load("int32", T_cast_78, (((i1_36*7168) + (i2_47*512)) + i3_48)), 1948805937, 31, -5, dtype="int32") - T_cast_79 = T.allocate([1, 14, 14, 512], "uint8", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as compute_23: + T_cast_79 = T.allocate(1*14*14*512, "uint8", "global") + with T.allocate(1*14*14*512, "int32", "global") as compute_23: for i1_37, i2_48, i3_49 in T.grid(14, 14, 512): 
compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(T.load("int32", compute_22, (((i1_37*7168) + (i2_48*512)) + i3_49)), 255), 0) for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 1aae8cdd03e1..cf885cd44bad 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -46,7 +46,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: with T.block([]): T.reads([]) T.writes(B[0:16, 0:16]) - A = T.allocate([256], "float32", "global") + A = T.allocate(256, "float32", "global") for i, j in T.grid(16, 16): T.store(A, i * 16 + j, 1) for i in range(0, 16): diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 00aba46ba431..2439864a6c72 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -155,7 +155,7 @@ def test_stmt_constructor(): assert x.value.value == 1 buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) - x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop) + x = tvm.tir.Allocate(buffer_var, "float32", 10, tvm.tir.const(1, "uint1"), nop) assert isinstance(x, tvm.tir.Allocate) assert x.dtype == "float32" assert x.buffer_var == buffer_var @@ -163,7 +163,7 @@ def test_stmt_constructor(): storage_scope = "global.texture" buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope)) - x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop) + x = tvm.tir.Allocate(buffer_var, "float32", 10, tvm.tir.const(1, "uint1"), nop) assert isinstance(x, tvm.tir.Allocate) assert x.dtype == "float32" assert x.buffer_var == buffer_var diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 5b123e883849..cff266ab0e0b 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -184,7 +184,7 @@ def test_ir(A, B, C): A = ib.buffer_ptr(A) B = ib.buffer_ptr(B) C = ib.buffer_ptr(C) - i = ib.allocate("int32", (1,), name="i", scope="local") + i = ib.allocate("int32", 1, name="i", scope="local") i[0] = 0 with ib.for_range(0, n) as j: @@ -242,8 +242,8 @@ def collatz_ref(n): return i def collatz(ib, n, C): - i = ib.allocate("int32", (1,), name="i", scope="local") - a = ib.allocate("int32", (1,), name="a", scope="local") + i = ib.allocate("int32", 1, name="i", scope="local") + a = ib.allocate("int32", 1, name="a", scope="local") i[0] = 0 a[0] = n with ib.while_loop(a[0] > 1): @@ -317,9 +317,9 @@ def complex_sqr(z): return pixels def mandel(ib, i, j, pixels): - z = ib.allocate("float32", (2,), name="z", scope="local") - tmp = ib.allocate("float32", (1,), name="tmp", scope="local") - iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + z = ib.allocate("float32", 2, name="z", scope="local") + tmp = ib.allocate("float32", 1, name="tmp", scope="local") + iterations = ib.allocate("int32", 1, name="iterations", scope="local") z[0] = (i / float(n) - 1) * 2 z[1] = (j / float(n) - 0.5) * 2 @@ -409,8 +409,8 @@ def check_target(target, ir): def test_while_binary_search(): def binary_search(ib, n, i, Aptr, Bptr, Cptr): - lo = ib.allocate("int32", (1,), name="lo", scope="local") - hi = ib.allocate("int32", 
(1,), name="hi", scope="local") + lo = ib.allocate("int32", 1, name="lo", scope="local") + hi = ib.allocate("int32", 1, name="hi", scope="local") lo[0] = 0 hi[0] = n @@ -509,7 +509,7 @@ def test_device_ir(A, B): tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", n) - temp = ib.allocate(dtype, (n,), scope="shared.dyn") # n is symbolic size + temp = ib.allocate(dtype, n, scope="shared.dyn") # n is symbolic size Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index fe719ee99693..c663dd0dd627 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -481,7 +481,7 @@ def test_tir_allocate(): allocate = tvm.tir.Allocate( buffer_var=a, dtype=dtype, - extents=[2, 2], + extent=4, condition=tvm.get_global_func("tir.const_true")(dtype, None), body=tvm.tir.Evaluate(2 + 1), annotations={ @@ -491,7 +491,7 @@ def test_tir_allocate(): ) assert allocate.buffer_var == a assert allocate.dtype == "int8" - assert list(allocate.extents) == [2, 2] + assert allocate.extent == 4 assert allocate.annotations["attr1"] == "foo" assert allocate.annotations["attr2"] == "bar" @@ -500,7 +500,7 @@ def test_tir_allocate(): output = func.astext() assert ( output.find( - 'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})' + 'allocate(buffer: Pointer(global int8), int8, 4), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})' ) != -1 ) diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index a91fa2591e00..aca80dc91f86 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -31,17 +31,19 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T. 
placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global") + PaddedInput_3 = T.buffer_decl([1,28,28,192], dtype='int16', scope='global') + T.realize(PaddedInput_3[0:1, 0:28, 0:28, 0:192], '') for i0_i1_fused_3 in T.parallel(0, 28): for i2_3, i3_3 in T.grid(28, 192): - T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + PaddedInput_3[0, i0_i1_fused_3, i2_3, i3_3] = placeholder_33[0, i0_i1_fused_3, i2_3, i3_3] for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): for ax3_2 in T.serial(0, 16): - Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") - T.store(Conv2dOutput_3, 0, 0, True) + Conv2dOutput_3 = T.buffer_decl([1], dtype='int32', scope='global') + T.realize(Conv2dOutput_3[0:1], '') + Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): - T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) - T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + Conv2dOutput_3[0] = Conv2dOutput_3[0] + T.cast(PaddedInput_3[0, ax0_ax1_fused_ax2_fused_3//28, ax0_ax1_fused_ax2_fused_3%28, rc_3], "int32")*T.cast(placeholder_34[0,0,rc_3,ax3_2], "int32") + T_cast_9[0, ax0_ax1_fused_ax2_fused_3//28, ax0_ax1_fused_ax2_fused_3%28, ax3_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[0,0,0,ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") # fmt: on diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 21c896c7bb7e..6d5241b2a4c9 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -53,7 +53,7 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in T.serial(0, 16): - B_new = T.allocate([16], "float32", "global") + B_new = T.allocate(16, "float32", "global") for j in T.serial(0, 16): B_new[j] = T.load("float32", A.data, ((i * 16) + j)) + 1.0 for j in T.serial(0, 16): @@ -95,7 +95,7 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") + B = T.allocate(16, "float32", "local") for j in range(0, 16): B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0 for j in range(0, 16): @@ -130,7 +130,7 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - B = T.allocate([m], "float32", "global") + B = T.allocate(m, "float32", "global") for j in range(0, m): B[j] = T.load("float32", A.data, i * m + j) + 1.0 for j in range(0, m): @@ -203,8 +203,8 @@ def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = 
T.match_buffer(d, (32), "float32") for i in range(0, 32): - B = T.allocate((32,), "float32", "global") - C = T.allocate((32,), "float32", "global") + B = T.allocate(32, "float32", "global") + C = T.allocate(32, "float32", "global") B[i] = T.load("float32", A.data, i) + 1.0 C[i] = T.load("float32", A.data, i) + T.load("float32", B, i) D.data[i] = T.load("float32", C, i) * 2.0 @@ -238,7 +238,7 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i0 in T.serial(0, 4): - B_new = T.allocate([68], "float32", "global") + B_new = T.allocate(68, "float32", "global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = T.load("float32", A.data, i0 * 64 + i1 * 16 + j) + 1.0 diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index 9b37bcaaacbc..821dc9a3cffa 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -47,10 +47,13 @@ def test_double_buffer(): mod = opt(mod) stmt = mod["db"].body + # Allocation of B is now twice as large assert isinstance(stmt.body, tvm.tir.Allocate) - assert stmt.body.extents[0].value == 2 + assert stmt.body.extent.value == 2 * m + + mod = tvm.tir.transform.ThreadSync("shared")(mod) + f = mod["db"] - f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] def count_sync(op): diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 673267a9b1fa..97f7ca9db617 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -14,95 +14,124 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import sys + +import pytest + import tvm +import tvm.testing from tvm import te +vthread_name = tvm.testing.parameter( + "vthread", + "cthread", +) +buffer_size = tvm.testing.parameter(4) +nthread = tvm.testing.parameter(2) -def test_vthread(): - dtype = "int64" - n = 100 - m = 4 - nthread = 2 - - def get_vthread(name): - tx = te.thread_axis(name) - ty = te.thread_axis(name) - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - C = ib.pointer("float32", name="C") - with ib.for_range(0, n) as i: - ib.scope_attr(tx, "virtual_thread", nthread) - ib.scope_attr(ty, "virtual_thread", nthread) - B = ib.allocate("float32", m, name="B", scope="shared") - B[i] = A[i * nthread + tx] - bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) - ib.emit( - tvm.tir.call_extern( - "int32", - "Run", - bbuffer.access_ptr("r"), - tvm.tir.call_intrin("int32", "tir.tvm_context_id"), - ) - ) - C[i * nthread + tx] = B[i] + 1 - return ib.get() - - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"] - - assert stmt.body.body.extents[0].value == 2 - - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"] - - assert len(stmt.body.body.extents) == 3 - - -def test_vthread_extern(): - dtype = "int64" - n = 100 - m = 4 - nthread = 2 - - def get_vthread(name): - tx = te.thread_axis(name) - ty = te.thread_axis(name) - ib = tvm.tir.ir_builder.create() - with ib.for_range(0, n) as i: - ib.scope_attr(tx, "virtual_thread", nthread) - ib.scope_attr(ty, "virtual_thread", nthread) - A = ib.allocate("float32", m, name="A", scope="shared") - B = ib.allocate("float32", m, name="B", scope="shared") - C = ib.allocate("float32", m, name="C", scope="shared") - cbuffer = tvm.tir.decl_buffer((m,), dtype=C.dtype, data=C.asobject()) - abuffer = tvm.tir.decl_buffer((m,), dtype=A.dtype, data=A.asobject()) - bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) - A[tx] = tx + 1.0 - B[ty] = ty + 1.0 - ib.emit( - tvm.tir.call_extern( - "int32", - "Run", - abuffer.access_ptr("r"), - bbuffer.access_ptr("r"), - cbuffer.access_ptr("rw"), - ) + +@tvm.testing.fixture +def vthread_mod(vthread_name, buffer_size, nthread): + loop_extent = 100 + + tx = te.thread_axis(vthread_name) + ty = te.thread_axis(vthread_name) + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="C") + with ib.for_range(0, loop_extent) as i: + ib.scope_attr(tx, "virtual_thread", nthread) + ib.scope_attr(ty, "virtual_thread", nthread) + B = ib.allocate("float32", buffer_size, name="B", scope="shared") + B[i] = A[i * nthread + tx] + bbuffer = tvm.tir.decl_buffer((buffer_size,), dtype=B.dtype, data=B.asobject()) + ib.emit( + tvm.tir.call_extern( + "int32", + "Run", + bbuffer.access_ptr("r"), + tvm.tir.call_intrin("int32", "tir.tvm_context_id"), ) - return ib.get() + ) + C[i * nthread + tx] = B[i] + 1 - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"] + return tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get())) - assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.extents) == 3 +def test_vthread(vthread_mod, vthread_name, buffer_size, nthread): + mod = tvm.tir.transform.InjectVirtualThread()(vthread_mod) + stmt = mod["main"] -def test_vthread_if_then_else(): - 
nthread = 2 + if vthread_name == "vthread": + # All virtual thread axes that start with "vthread" share the + # same iteration, similar to threadIdx.x, so the number of + # virtual threads is nthread. + expected_buffer_size = buffer_size * nthread + elif vthread_name == "cthread": + # All other virtual thread axes are independent, so tx and ty + # are independent and the total number of virtual threads is + # nthread*nthread. + expected_buffer_size = buffer_size * nthread * nthread + else: + raise ValueError(f"Unexpected vthread_name: {vthread_name}") + + assert stmt.body.body.extent.value == expected_buffer_size + + +@tvm.testing.fixture +def vthread_extern_mod(vthread_name, buffer_size, nthread): + loop_extent = 100 + + tx = te.thread_axis(vthread_name) + ty = te.thread_axis(vthread_name) + ib = tvm.tir.ir_builder.create() + with ib.for_range(0, loop_extent) as i: + ib.scope_attr(tx, "virtual_thread", nthread) + ib.scope_attr(ty, "virtual_thread", nthread) + A = ib.allocate("float32", buffer_size, name="A", scope="shared") + B = ib.allocate("float32", buffer_size, name="B", scope="shared") + C = ib.allocate("float32", buffer_size, name="C", scope="shared") + cbuffer = tvm.tir.decl_buffer((buffer_size,), dtype=C.dtype, data=C.asobject()) + abuffer = tvm.tir.decl_buffer((buffer_size,), dtype=A.dtype, data=A.asobject()) + bbuffer = tvm.tir.decl_buffer((buffer_size,), dtype=B.dtype, data=B.asobject()) + A[tx] = tx + 1.0 + B[ty] = ty + 1.0 + ib.emit( + tvm.tir.call_extern( + "int32", + "Run", + abuffer.access_ptr("r"), + bbuffer.access_ptr("r"), + cbuffer.access_ptr("rw"), + ) + ) + return tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get())) + + +def test_vthread_extern(vthread_extern_mod, vthread_name, buffer_size, nthread): + mod = tvm.tir.transform.InjectVirtualThread()(vthread_extern_mod) + stmt = mod["main"] + + if vthread_name == "vthread": + # The shared A and B buffers are only exposed as read-only to + # the external function, so they can still share the allocated + # space.
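+ # Concretely, with the parameter values above (buffer_size=4, nthread=2), + # "vthread" gives the read-only A and B allocations an extent of 4*2 = 8 + # each while the writable C gets 4*2*2 = 16, and "cthread" gives all + # three an extent of 4*2*2 = 16.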
+        ro_buffer_size = buffer_size * nthread
+        rw_buffer_size = buffer_size * nthread * nthread
+    elif vthread_name == "cthread":
+        ro_buffer_size = buffer_size * nthread * nthread
+        rw_buffer_size = buffer_size * nthread * nthread
+    else:
+        raise ValueError(f"Unexpected vthread_name: {vthread_name}")
+
+    A_alloc = stmt.body.body
+    C_alloc = A_alloc.body.body
+    assert A_alloc.extent.value == ro_buffer_size
+    assert C_alloc.extent.value == rw_buffer_size
+
+
+@tvm.testing.fixture
+def vthread_if_then_else_mod(nthread):
     tx = te.thread_axis("vthread")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
@@ -115,17 +144,15 @@ def test_vthread_if_then_else():
             B[i] = A[i * nthread + tx] + 1
         with ib.if_scope(i == 0):
             B[i] = A[i * nthread + tx] + 2
-    stmt = ib.get()
+    return tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get()))
+

-    stmt = tvm.tir.transform.InjectVirtualThread()(
-        tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
-    )["main"]
+def test_vthread_if_then_else(vthread_if_then_else_mod):
+    stmt = tvm.tir.transform.InjectVirtualThread()(vthread_if_then_else_mod)["main"]

     assert stmt.body.body.body[0].else_case != None
     assert stmt.body.body.body[1].else_case == None


 if __name__ == "__main__":
-    test_vthread_extern()
-    test_vthread()
-    test_vthread_if_then_else()
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
index 63772dea65d7..7b38fc6120c3 100644
--- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
@@ -157,8 +157,8 @@ def build_tir():
         Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32"))
         # return handle
         # let Aptr_var = testing.echo(Aptr) in Aptr_var[1] = expected_value[1]
-        Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject()))
-        ib.emit(tvm.tir.Store(Aptr, tvm.tir.const(expected_value[1], "float32"), 1))
+        Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject().data))
+        ib.emit(tvm.tir.Store(Aptr_var, tvm.tir.const(expected_value[1], "float32"), 1))

     stmt = ib.get()
     return tvm.IRModule.from_expr(
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index 675a7feb3b1f..43c45ba457e7 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -22,15 +22,19 @@
 import tvm.testing


-@tvm.testing.requires_cuda
-def test_lower_warp_memory_local_scope():
-    m = 128
-    A = te.placeholder((m,), name="A")
-    B = te.compute((m,), lambda i: A[i] + 3, name="B")
+@tvm.testing.parametrize_targets("cuda")
+def test_lower_warp_memory_local_scope(target):
+    target = tvm.target.Target(target)
+    assert target.thread_warp_size == 32
+
+    arr_size = 128
+    cache_size = 64
+    A = te.placeholder((arr_size,), name="A")
+    B = te.compute((arr_size,), lambda i: A[i] + 3, name="B")
     s = te.create_schedule(B.op)
     AA = s.cache_read(A, "warp", [B])
-    xo, xi = s[B].split(B.op.axis[0], 64)
+    xo, xi = s[B].split(B.op.axis[0], cache_size)
     xi0, xi1 = s[B].split(xi, factor=32)
     tx = te.thread_axis("threadIdx.x")
     s[B].bind(xi1, tx)
@@ -39,17 +43,16 @@ def test_lower_warp_memory_local_scope():
     xo, xi = s[AA].split(s[AA].op.axis[0], 32)
     s[AA].bind(xi, tx)

-    cuda_target = tvm.target.Target("cuda")
-    assert cuda_target.thread_warp_size == 32
     mod = tvm.lower(s, [A, B], name="f")
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
-    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+    mod = tvm.tir.transform.LowerWarpMemory()(mod)
+    fdevice = mod["f_kernel0"]
     allocate = fdevice.body.body
     assert allocate.buffer_var.type_annotation.storage_scope == "local"
-    assert fdevice.body.body.extents[0].value == 2
+    assert fdevice.body.body.extent.value * target.thread_warp_size == cache_size


 @tvm.testing.requires_cuda
diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
index 9d917466758b..649e7e6064d5 100644
--- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py
+++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
@@ -132,7 +132,7 @@ def test_body():
     ib = tvm.tir.ir_builder.create()
     A = tvm.tir.decl_buffer(name="A", shape=[1])
     B = tvm.tir.decl_buffer(name="B", shape=[1])
-    C = ib.buffer_ptr(A)
+    C = ib.buffer_ptr(A.data)
     stmt = ib.get()

     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, B, C], stmt))
diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
index 9c511f1de6b9..00ec510a9759 100644
--- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -41,21 +41,23 @@ def run_passes(sch, args):

 def verify_single_allocation(stmt, alloc_size=None):
     num_alloc = [0]
-    alloc_extents = []
+    alloc_extent = 1

     def verify(n):
+        nonlocal alloc_extent
+
         if (
             isinstance(n, tvm.tir.Allocate)
             and n.buffer_var.type_annotation.storage_scope == "shared.dyn"
         ):
             num_alloc[0] += 1
-            alloc_extents.append(n.extents[0])
+            alloc_extent *= n.extent

     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     assert num_alloc[0] == 1

     if alloc_size:
-        assert alloc_extents[0] == alloc_size
+        assert alloc_extent == alloc_size


 @tvm.testing.requires_gpu
@@ -80,12 +82,12 @@ def test_matmul_ir(A, B, C):
         ib.scope_attr(bx, "thread_extent", n // block)
         ib.scope_attr(by, "thread_extent", n // block)

-        A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh")  # fp16
-        B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh")  # fp16
+        A_sh = ib.buffer_realize(A.dtype, (block, block), scope="shared.dyn", name="A_sh")  # fp16
+        B_sh = ib.buffer_realize(B.dtype, (block, block), scope="shared.dyn", name="B_sh")  # fp16
         # Create a dynamic shared memory for the accumulation.
         # This is for testing merging dynamic shared memory alloctions with different data type.
         # In practice, there is no need to allocate a shared memory for C.
-        C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh")  # fp32
+        C_sh = ib.buffer_realize(C.dtype, (block, block), scope="shared.dyn", name="C_sh")  # fp32

         A_ptr = ib.buffer_ptr(A)
         B_ptr = ib.buffer_ptr(B)
@@ -155,8 +157,8 @@ def test_device_ir(A, B, C):
         tx = te.thread_axis("threadIdx.x")
         ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread))

-        A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn")  # fp16
-        B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn")  # fp32
+        A_sh = ib.allocate(A.dtype, n, scope="shared.dyn")  # fp16
+        B_sh = ib.allocate(B.dtype, n, scope="shared.dyn")  # fp32

         Aptr = ib.buffer_ptr(A)
         Bptr = ib.buffer_ptr(B)
@@ -218,9 +220,9 @@ def test_device_ir(A, B, C, D):
         tx = te.thread_axis("threadIdx.x")
         ib.scope_attr(tx, "thread_extent", n)

-        A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn", name="A_sh")
-        B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn", name="B_sh")
-        C_sh = ib.allocate(C.dtype, (C.shape[0],), scope="shared.dyn", name="C_sh")
+        A_sh = ib.allocate(A.dtype, n, scope="shared.dyn", name="A_sh")
+        B_sh = ib.allocate(B.dtype, n, scope="shared.dyn", name="B_sh")
+        C_sh = ib.allocate(C.dtype, C.shape[0], scope="shared.dyn", name="C_sh")

         Aptr = ib.buffer_ptr(A)
         Bptr = ib.buffer_ptr(B)
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 37223493a8b5..ba508a6c0b4a 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -81,25 +81,25 @@ def test_flatten_storage_align():
     )(mod)
     stmt = mod["main"].body

-    assert stmt.extents[0].value == 17 * 8
+    assert stmt.extent.value == 17 * 8


 def test_flatten_double_buffer():
     dtype = "int64"
     n = 100
-    m = 4
+    buffer_size = 4
     tx = te.thread_axis("threadIdx.x")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     C = ib.pointer("float32", name="C")
     ib.scope_attr(tx, "thread_extent", 1)
     with ib.for_range(0, n) as i:
-        B = ib.allocate("float32", m, name="B", scope="shared")
+        B = ib.allocate("float32", buffer_size, name="B", scope="shared")
         with ib.new_scope():
             ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
-            with ib.for_range(0, m) as j:
+            with ib.for_range(0, buffer_size) as j:
                 B[j] = A[i * 4 + j]
-        with ib.for_range(0, m) as j:
+        with ib.for_range(0, buffer_size) as j:
             C[j] = B[j] + 1
     stmt = ib.get()

@@ -119,7 +119,7 @@ def test_flatten_double_buffer():
     stmt = mod["main"].body
     assert isinstance(stmt.body, tvm.tir.Allocate)
-    assert stmt.body.extents[0].value == 2
+    assert stmt.body.extent.value == 2 * buffer_size

     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db"))
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index 9e738b136b17..22446f8b6516 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -87,7 +87,7 @@ def test_alloc_seq():
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
-            assert n.extents[0].value == 200
+            assert n.extent.value == 200

     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
@@ -137,7 +137,7 @@ def offset_generater(dtype_list, length):
     def dtype_test(dtype_list, length):
         def verify(n):
             if isinstance(n, tvm.tir.Allocate):
-                assert n.extents[0].value == offset
+                assert n.extent.value == offset

     body = stmt_generater(dtype_list, length)
     offset = offset_generater(dtype_list, length)
@@ -222,7 +222,7 @@ def test_storage_combine():
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
-            assert n.extents[0].value == 16
+            assert n.extent.value == 16

     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     assert num_alloc[0] == 1
@@ -527,7 +527,7 @@ def test_inplace_rule3():
     # verify inplace folding works
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
-            assert n.extents[0].value == 70
+            assert n.extent.value == 70

     tvm.tir.stmt_functor.post_order_visit(stmt, verify)

@@ -560,7 +560,7 @@ def test_alloc_seq_type():
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
-            assert n.extents[0].value == 500
+            assert n.extent.value == 500

     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
@@ -595,7 +595,7 @@ def test_alloc_seq_type2():
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
-            assert n.extents[0].value == 200
+            assert n.extent.value == 200

     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
@@ -629,7 +629,7 @@ def test_reuse_small_buffer():
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
-            assert n.extents[0].value == 800
+            assert n.extent.value == 800

     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
@@ -670,7 +670,7 @@ def compute(a, b):

     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
-            assert n.extents[0].value == 268435456
+            assert n.extent.value == 268435456

     tvm.tir.stmt_functor.post_order_visit(stmt, verify)

diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index ffdf4b5916c4..c21ac22862c7 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -69,8 +69,8 @@ def ir(A, B):
         tx = te.thread_axis("threadIdx.x")
         ib.scope_attr(tx, "thread_extent", 1)

-        local = ib.allocate(A.dtype, (8,), name="buf_local", scope="local")
-        shared = ib.allocate(A.dtype, (8,), name="buf_shared", scope="shared")
+        local = ib.allocate(A.dtype, 8, name="buf_local", scope="local")
+        shared = ib.allocate(A.dtype, 8, name="buf_shared", scope="shared")

         with ib.for_range(0, 8) as i:
             with ib.if_scope(Aptr[i] < 0):
diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py
index b511118f8b52..4989742dcec7 100644
--- a/tests/python/unittest/test_tir_transform_unroll_loop.py
+++ b/tests/python/unittest/test_tir_transform_unroll_loop.py
@@ -14,13 +14,19 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+import sys
+
+import pytest
+
 import tvm
+import tvm.testing
 from tvm import te
-import os


-def test_unroll_loop():
+@tvm.testing.fixture
+def loop_module():
     ib = tvm.tir.ir_builder.create()
+
     dtype = "int64"
     n = te.size_var("n")
     Ab = tvm.tir.decl_buffer((n,), dtype)
@@ -31,41 +37,58 @@
             Aptr[j + 1] = Aptr[i] + 1
     stmt = ib.get()

-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
-    assert isinstance(stmt, tvm.tir.For)
+    return tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
+

+def test_auto_unroll_enabled_below_limit(loop_module):
     with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert not isinstance(ret, tvm.tir.For)
+        mod = tvm.tir.transform.UnrollLoop()(loop_module)
+        body = mod["main"].body
+        assert not isinstance(body, tvm.tir.For)
+

+def test_auto_unroll_disabled_above_limit(loop_module):
     with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 15}}):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert isinstance(ret, tvm.tir.For)
+        mod = tvm.tir.transform.UnrollLoop()(loop_module)
+        body = mod["main"].body
+        assert isinstance(body, tvm.tir.For)
+

+def test_explicit_unroll(loop_module):
     with tvm.transform.PassContext(
         config={"tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": False}}
     ):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert isinstance(ret, tvm.tir.For)
-        assert ret.kind == tvm.tir.ForKind.UNROLLED
+        mod = tvm.tir.transform.UnrollLoop()(loop_module)
+        body = mod["main"].body
+        assert isinstance(body, tvm.tir.For)
+        assert body.kind == tvm.tir.ForKind.UNROLLED
+

+@tvm.testing.fixture
+def loop_module_pragma_sequential(loop_module):
+    orig_body = loop_module["main"].body
     ib = tvm.tir.ir_builder.create()
     ib.scope_attr(tvm.tir.const(0, "int32"), "pragma_auto_unroll_max_step", 16)
-    ib.emit(stmt)
+    ib.emit(orig_body)
     wrapped = ib.get()
-    wrapped = tvm.tir.SeqStmt([wrapped, stmt])
-    assert isinstance(ret, tvm.tir.For)
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped))
+    body = tvm.tir.SeqStmt([wrapped, orig_body])
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc(loop_module["main"].params, body))
+    return mod
+
+
+def test_pragma_unroll(loop_module_pragma_sequential):
     with tvm.transform.PassContext(
         config={"tir.UnrollLoop": {"auto_max_depth": 8, "explicit_unroll": False}}
     ):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert isinstance(ret[0], tvm.tir.For)
-        assert ret[0].kind == tvm.tir.ForKind.UNROLLED
-        assert isinstance(ret[1], tvm.tir.For)
-        assert ret[1].kind != tvm.tir.ForKind.UNROLLED
+        mod = tvm.tir.transform.UnrollLoop()(loop_module_pragma_sequential)
+        body = mod["main"].body
+        assert isinstance(body[0], tvm.tir.For)
+        assert body[0].kind == tvm.tir.ForKind.UNROLLED
+        assert isinstance(body[1], tvm.tir.For)
+        assert body[1].kind != tvm.tir.ForKind.UNROLLED


 def test_unroll_fake_loop():
@@ -89,8 +112,9 @@ def test_unroll_fake_loop():
             "tir.UnrollLoop": {"auto_max_depth": 8, "auto_max_extent": 1, "explicit_unroll": False}
         }
     ):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert isinstance(ret[0], tvm.tir.Store)
+        mod = tvm.tir.transform.UnrollLoop()(mod)
+        body = mod["main"].body
+        assert isinstance(body[0], tvm.tir.BufferStore)


 def test_unroll_single_count_loops():
@@ -111,6 +135,4 @@ def test_unroll_single_count_loops():


 if __name__ == "__main__":
-    test_unroll_loop()
-    test_unroll_fake_loop()
-    test_unroll_single_count_loops()
+    sys.exit(pytest.main(sys.argv))
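An aside for readers skimming the test updates (not part of the patch): the rewritten assertions all lean on the same post-patch invariant, namely that an allocation now carries a single scalar extent exposed as `.extent` where the tests previously indexed `.extents[0]`. A minimal sketch of the pattern the updated tests use, assuming this branch is applied; the buffer name and the size 16 are illustrative only:

import tvm

# Build a trivial allocation with the IR builder; with this patch the
# size argument of ib.allocate is a single scalar extent.
ib = tvm.tir.ir_builder.create()
A = ib.allocate("float32", 16, name="A", scope="global")
A[0] = 0.0
stmt = ib.get()

# The verify/post_order_visit idiom used throughout the tests above.
def verify(n):
    if isinstance(n, tvm.tir.Allocate):
        # `extent` is a PrimExpr, so compare its wrapped integer value.
        assert n.extent.value == 16

tvm.tir.stmt_functor.post_order_visit(stmt, verify)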
diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py
index b1e580957b24..5a6e7f682996 100644
--- a/tests/python/unittest/test_tir_transform_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -170,7 +170,7 @@ def test_ir(A, B, C):
         A = ib.buffer_ptr(A)
         B = ib.buffer_ptr(B)
         C = ib.buffer_ptr(C)
-        i = ib.allocate("int32", (1,), name="i", scope="local")
+        i = ib.allocate("int32", 1, name="i", scope="local")
         i[0] = 0

         with ib.for_range(0, n) as j:
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 99a22636b927..076d707e361b 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -147,7 +147,7 @@ def test_no_body():

 def allocate_with_buffers() -> None:
-    with T.allocate([1], "float32", "") as [A, B]:  # error
+    with T.allocate(1, "float32", "") as [A, B]:  # error
         T.evaluate(1.0)

@@ -384,7 +384,7 @@ def test_match_buffer_shape_mismatch():

 def high_dim_store() -> None:
     with T.block([], "root"):
-        B = T.allocate([256], "float32", "global")
+        B = T.allocate(256, "float32", "global")
         for i, j in T.grid(16, 16):
             B[i, j] = 1.0  # error: Store is only allowed with one index

diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 8058b96b024d..020cacff0a79 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -93,7 +93,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
     B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
     C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
     # body
-    packedB = T.allocate([32768], "float32x32", "global")
+    packedB = T.allocate(32768, "float32x32", "global")
     for x in T.parallel(0, 32):
         for y in T.serial(0, 1024):
             T.store(
@@ -108,7 +108,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
                 T.broadcast(True, 32),
             )
     for x_outer in T.parallel(0, 32):
-        C_global = T.allocate([1024], "float32", "global")
+        C_global = T.allocate(1024, "float32", "global")
         for y_outer in T.serial(0, 32):
             for x_c_init in T.serial(0, 32):
                 T.store(
@@ -1080,11 +1080,11 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None:
     ty = T.env_thread("threadIdx.y")
     tz = T.env_thread("threadIdx.z")
     T.launch_thread(bz, 196)
-    Conv_wmma_accumulator = T.allocate([2048], "float32", "wmma.accumulator")
-    Apad_shared = T.allocate([12288], "float16", "shared")
-    W_shared = T.allocate([12288], "float16", "shared")
-    Apad_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a")
-    W_shared_wmma_matrix_b = T.allocate([1024], "float16", "wmma.matrix_b")
+    Conv_wmma_accumulator = T.allocate(2048, "float32", "wmma.accumulator")
+    Apad_shared = T.allocate(12288, "float16", "shared")
+    W_shared = T.allocate(12288, "float16", "shared")
+    Apad_shared_wmma_matrix_a = T.allocate(512, "float16", "wmma.matrix_a")
+    W_shared_wmma_matrix_b = T.allocate(1024, "float16", "wmma.matrix_b")
     T.launch_thread(bx, 2)
     T.launch_thread(by, 4)
     T.launch_thread(ty, 4)
@@ -2653,7 +2653,7 @@ def vthread_func(a: T.handle, c: T.handle) -> None:
     T.launch_thread(i0, 4)
     T.launch_thread(i1, 2)
     T.launch_thread(i2, 2)
-    B = T.allocate([16], "float32", "local")
+    B = T.allocate(16, "float32", "local")
     for j in range(16):
         B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + T.float32(1)
     for j in range(16):
@@ -3067,7 +3067,7 @@ def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.han
     placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
     T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
     # body
-    tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"})
+    tensor_2 = T.allocate(200704, "uint8", "global", annotations={"attr1_key": "attr1_value"})
     for ax0_ax1_fused_4 in T.serial(0, 56):
         for ax2_4 in T.serial(0, 56):
             for ax3_init in T.serial(0, 64):
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 383841f19e34..40672b24516b 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -172,7 +172,7 @@ def _post_order(op):
                 ),
                 op.body,
             )
-            alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extents, op.condition, let_stmt)
+            alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extent, op.condition, let_stmt)
             del rw_info[buffer_var]
             return alloc
         if isinstance(op, tvm.tir.Load):
@@ -226,7 +226,7 @@ def _merge_block(slist, body):
             if op.body == body:
                 body = op
         elif isinstance(op, tvm.tir.Allocate):
-            body = tvm.tir.Allocate(op.buffer_var, op.dtype, op.extents, op.condition, body)
+            body = tvm.tir.Allocate(op.buffer_var, op.dtype, op.extent, op.condition, body)
         elif isinstance(op, tvm.tir.AttrStmt):
             body = tvm.tir.AttrStmt(op.node, op.attr_key, op.value, body)
         elif isinstance(op, tvm.tir.For):
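A closing aside (not part of the patch): passes that rebuild allocation nodes directly, as the VTA transform above does, keep the same constructor argument order; only the third argument narrows from an array of extents to a single PrimExpr. A hedged sketch under that assumption; the variable names and the 64-element size are illustrative, and the optional annotations and span arguments are omitted:

import tvm

dtype = "float32"
# A pointer-typed variable to serve as the allocation's buffer var.
buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType(dtype)))

# Post-patch signature: Allocate(buffer_var, dtype, extent, condition, body).
alloc = tvm.tir.Allocate(
    buffer_var,
    dtype,
    tvm.tir.const(64, "int32"),  # single extent; previously an Array such as [64]
    tvm.tir.const(1, "uint1"),   # always-true allocation condition
    tvm.tir.Evaluate(0),         # placeholder body
)
assert alloc.extent.value == 64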