diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index b95d575360e6..ee8032236252 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -75,8 +75,6 @@ class PrimFuncFrameNode : public TIRFrameNode { Optional ret_type; /*! \brief Maps some parameters to specific Buffer data structures. */ Map buffer_map; - /*! \brief The buffer map prior to flattening. */ - Map preflattened_buffer_map; /*! \brief Additional attributes storing the meta-data */ Optional> attrs; /*! \brief The variable map bound to thread env. */ @@ -90,7 +88,6 @@ class PrimFuncFrameNode : public TIRFrameNode { v->Visit("args", &args); v->Visit("ret_type", &ret_type); v->Visit("buffer_map", &buffer_map); - v->Visit("preflattened_buffer_map", &preflattened_buffer_map); v->Visit("attrs", &attrs); v->Visit("env_threads", &env_threads); v->Visit("root_alloc_buffers", &root_alloc_buffers); diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index d9e1a1b49063..5cba87920580 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -114,26 +114,6 @@ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = Data int align = -1, int offset_factor = 0, String buffer_type = "default", Array axis_separators = {}); -/*! - * \brief The pre-flattened buffer statement. - * \param postflattened_buffer The original buffer to be flattened. - * \param shape The type of the buffer prior to flattening. - * \param dtype The data type in the content of the buffer. - * \param data The pointer to the head of the data. - * \param strides The strides of each dimension. - * \param elem_offset The offset in terms of number of dtype elements (including lanes). - * \param storage_scope The optional storage scope of buffer data pointer. - * \param align The alignment requirement of data pointer in bytes. - * \param offset_factor The factor of elem_offset field. - * \param buffer_type The buffer type. - * \param axis_separators The separators between input axes when generating flattened output axes. - */ -void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, - DataType dtype = DataType::Float(32), Optional data = NullOpt, - Array strides = {}, PrimExpr elem_offset = PrimExpr(), - String storage_scope = "global", int align = -1, int offset_factor = 0, - String buffer_type = "default", Array axis_separators = {}); - /*! * \brief The block declaration statement. * \param name The name of the block. diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index d793d84fc677..cf92f97360b1 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -88,33 +88,22 @@ class PrimFuncNode : public BaseFuncNode { * While we could have express parameter unpacking and constraint using * normal statements, making buffer_map as first class citizen of PrimFunc * will make program analysis much easier. - */ - Map buffer_map; - - /*! \brief The buffer map prior to flattening. - * - * This contains the buffers as they exists prior to flattening, and - * is used for validating an input tensor passed into the packed - * API. Any buffer that is present in `buffer_map` but not present - * in `preflattened_buffer_map` is assumed to be the same before - * and after flattening (e.g. a 1-d tensor that is backed by 1-d - * flat memory). 
* - * TODO(Lunderberg): Remove preflattened_buffer_map, and instead - * declare each flattened buffer as aliasing the original tensor - * shape. This should include improving the StmtExprMutator to - * provide easier interactions with Buffer objects, so that the - * bookkeeping of relationships between buffers doesn't need to be - * repeated across several transforms. + * Prior to buffer flattening, which is performed either in + * StorageFlatten for TE-based schedules or in FlattenBuffer for + * TIR-based schedules, these buffer objects are used directly in + * the body of the function. After buffer flattening, these buffer + * objects remain unflattened for use in argument validation, but + * all usage in the body of the function is done through a + * flattened alias of the buffer. */ - Map preflattened_buffer_map; + Map buffer_map; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("params", ¶ms); v->Visit("body", &body); v->Visit("ret_type", &ret_type); v->Visit("buffer_map", &buffer_map); - v->Visit("preflattened_buffer_map", &preflattened_buffer_map); v->Visit("attrs", &attrs); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -123,7 +112,6 @@ class PrimFuncNode : public BaseFuncNode { bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { // visit params and buffer_map first as they contains defs. return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && - equal(preflattened_buffer_map, other->preflattened_buffer_map) && equal(ret_type, other->ret_type) && equal(body, other->body) && equal(attrs, other->attrs); } @@ -131,7 +119,6 @@ class PrimFuncNode : public BaseFuncNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(params); hash_reduce(buffer_map); - hash_reduce(preflattened_buffer_map); hash_reduce(ret_type); hash_reduce(body); hash_reduce(attrs); @@ -169,21 +156,13 @@ class PrimFunc : public BaseFunc { * PrimFunc. (e.g. a buffer of shape ``[1024]`` originally * generated as a tensor of shape ``[32, 32]``) * - * \param preflattened_buffer_map The buffer map for - * parameter buffer unpacking. This contains buffer - * objects as they are expected to be passed in by the - * callee. (e.g. a buffer of shape ``[32, 32]`` originally - * generated as a tensor of shape ``[32, 32]``) - * * \param attrs Additional function attributes. * * \param span The location of this object in the source code. 
*/ - TVM_DLL PrimFunc( - Array params, Stmt body, Type ret_type = VoidType(), - Map buffer_map = Map(), - Optional> preflattened_buffer_map = Optional>(), - DictAttrs attrs = NullValue(), Span span = Span()); + TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), + Map buffer_map = Map(), + DictAttrs attrs = NullValue(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index cc94c6e816cd..e15d126dd969 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -299,7 +299,6 @@ def _ftransform(f, mod, ctx): new_body, f.ret_type, new_buffer_map, - f.preflattened_buffer_map, f.attrs, f.span, ) @@ -327,7 +326,7 @@ def EncodeConstants(const_dict): """ new_const_dict = {} - def collect_encoding_definitions(stmt, old_buffer_to_const): + def collect_encoding_definitions(stmt, old_buffer_var_to_const): # Map from copy destination to copy source. copy_map = {} # List of buffer copies that occurred @@ -376,7 +375,7 @@ def _declare_constant_buffer(old_buffer, encoded_constants, split_idx): def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func): """Encode the weights or align the bias either for one or two cores, depending on the variant.""" - constant = old_buffer_to_const[buffer1] + constant = old_buffer_var_to_const[buffer1.data] # If we have just one core, encode the whole constant if buffer2 is None: @@ -471,7 +470,12 @@ def _visit(stmt): } def transform_stmt( - stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const, new_buffer_to_split_idx + stmt, + buf_remap, + var_remap, + pointer_to_buffer, + new_buffer_var_to_const, + new_buffer_to_split_idx, ): def _visit_rewrite(stmt): if isinstance(stmt, tvm.tir.Call): @@ -485,7 +489,7 @@ def _visit_rewrite(stmt): # encoded buffer, the current should be a length. if ( isinstance(prev_arg, tvm.tir.BufferLoad) - and prev_arg.buffer in new_buffer_to_const + and prev_arg.buffer.data in new_buffer_var_to_const ): buffer_size = np.prod(list(prev_arg.buffer.shape)) arg = buffer_size @@ -554,28 +558,56 @@ def _visit_rewrite(stmt): ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"], ) + def _collect_parameter_buffer_aliases(prim_func): + buffer_vars = {} + for param in prim_func.params: + if param in prim_func.buffer_map: + buf = prim_func.buffer_map[param] + buffer_vars[buf.data] = {buf} + + def visit(node): + if isinstance(node, (tvm.tir.BufferStore, tvm.tir.BufferLoad, tvm.tir.DeclBuffer)): + buf = node.buffer + if buf.data in buffer_vars: + buffer_vars[buf.data].add(buf) + + tvm.tir.stmt_functor.post_order_visit(prim_func.body, visit) + return buffer_vars + def _ftransform(f, mod, ctx): + param_buffer_var_usage = _collect_parameter_buffer_aliases(f) + # Step 0: Unpack the constant dictionary in terms of the # functions buffers. - old_buffer_to_const = {} + old_buffer_var_to_const = {} for i, param in enumerate(f.params): if i in const_dict: - old_buffer_to_const[f.buffer_map[param]] = const_dict[i] + old_buffer_var_to_const[f.buffer_map[param].data] = const_dict[i] # Step 1: Collect information on the buffers that will be # replaced by encodings. 
- buffer_information = collect_encoding_definitions(f.body, old_buffer_to_const) + buffer_information = collect_encoding_definitions(f.body, old_buffer_var_to_const) # Step 2: Generate variable/buffer remaps, based on the # collected information. buf_remap = {} - new_buffer_to_const = {} + new_buffer_var_to_const = {} new_buffer_to_split_idx = {} + def define_remap(old_buf, new_buf): + try: + old_buffers = param_buffer_var_usage[old_buf.data] + except KeyError: + old_buffers = [old_buf] + + for old_buffer in old_buffers: + buf_remap[old_buffer] = new_buf + # Any encoded buffers must be replaced for info in buffer_information["constant_buffer_replacements"]: - buf_remap[info["old_buffer"]] = info["new_buffer"] - new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"] + define_remap(info["old_buffer"], info["new_buffer"]) + + new_buffer_var_to_const[info["new_buffer"].data] = info["encoded_constants"] if info["split_idx"]: new_buffer_to_split_idx[info["new_buffer"]] = info["split_idx"] @@ -596,9 +628,11 @@ def _ftransform(f, mod, ctx): name=copy_dest.name, scope=copy_dest.scope(), ) - buf_remap[copy_dest] = new_dest - if copy_source in new_buffer_to_const: - new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source] + define_remap(copy_dest, new_dest) + if copy_source.data in new_buffer_var_to_const: + new_buffer_var_to_const[new_dest.data] = new_buffer_var_to_const[ + copy_source.data + ] if copy_source in new_buffer_to_split_idx: new_buffer_to_split_idx[new_dest] = new_buffer_to_split_idx[copy_source] @@ -615,7 +649,7 @@ def _ftransform(f, mod, ctx): buf_remap, var_remap, pointer_to_buffer, - new_buffer_to_const, + new_buffer_var_to_const, new_buffer_to_split_idx, ) @@ -626,10 +660,10 @@ def _ftransform(f, mod, ctx): if buffer in buf_remap: buffer = buf_remap[buffer] - if buffer in new_buffer_to_const: - new_const_dict[i] = new_buffer_to_const[buffer].flatten() - elif buffer in old_buffer_to_const: - new_const_dict[i] = old_buffer_to_const[buffer].flatten() + if buffer.data in new_buffer_var_to_const: + new_const_dict[i] = new_buffer_var_to_const[buffer.data].flatten() + elif buffer.data in old_buffer_var_to_const: + new_const_dict[i] = old_buffer_var_to_const[buffer.data].flatten() new_buffer_map[param] = buffer @@ -638,7 +672,6 @@ def _ftransform(f, mod, ctx): new_body, f.ret_type, new_buffer_map, - f.preflattened_buffer_map, f.attrs, f.span, ) @@ -873,7 +906,6 @@ def CreatePrimFuncWithoutConstants(const_dict): def _ftransform(f, mod, ctx): new_params = list() new_buffer_map = dict() - new_preflattened_buffer_map = dict() for param_idx in const_dict.keys(): # We are using buffer_var to key the constants as # PrimFunc params of constants will be removed. 
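The rekeying above (from `old_buffer_to_const` to `old_buffer_var_to_const`, together with `_collect_parameter_buffer_aliases` and `define_remap`) is needed because a parameter buffer and its flattened alias in the function body are now distinct Buffer objects that share one data Var. A minimal sketch of that lookup, with hypothetical buffer names and not part of the patch itself:

```python
import numpy as np
import tvm
from tvm import tir

# A 4-d parameter buffer and a flat alias over the same backing Var, the shape
# the updated flattening passes now leave in place.
param_buf = tir.decl_buffer((1, 16, 16, 32), "int8", "placeholder")
flat_alias = tir.decl_buffer((8192,), "int8", "placeholder_flat", data=param_buf.data)

# Keying constants by the Buffer object would miss the alias; keying by the
# shared data Var resolves either view of the memory to the same entry.
const_by_var = {param_buf.data: np.zeros((1, 16, 16, 32), "int8")}
assert flat_alias.data in const_by_var
assert flat_alias not in {param_buf: 0}
```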
@@ -882,14 +914,11 @@ def _ftransform(f, mod, ctx): if i not in const_dict.keys(): new_params.append(param) new_buffer_map[param] = f.buffer_map[param] - if param in f.preflattened_buffer_map: - new_preflattened_buffer_map[param] = f.preflattened_buffer_map[param] return tvm.tir.PrimFunc( new_params, f.body, f.ret_type, new_buffer_map, - new_preflattened_buffer_map, f.attrs, f.span, ) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 0678925e2f7c..842e21378fd1 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -314,74 +314,6 @@ def match_buffer( ) -def preflattened_buffer( - postflattened: Buffer, - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], - dtype: str = "float32", - data: Var = None, - strides: List[PrimExpr] = None, - elem_offset: PrimExpr = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: List[int] = None, -) -> None: - """The pre-flattened buffer statement. - - Parameters - ---------- - postflattened : Buffer - The original buffer to be flattened. - - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] - The type of the buffer prior to flattening. - - dtype : str - The data type in the content of the buffer. - - data : Var - The pointer to the head of the data. - - strides : List[PrimExpr] - The strides of each dimension. - - elem_offset : PrimExpr - The offset in terms of number of dtype elements (including lanes). - - scope : str - The optional storage scope of buffer data pointer. - - align : int - The alignment requirement of data pointer in bytes. - - offset_factor : int - The factor of elem_offset field. - - buffer_type : str - The buffer type. - - axis_separators : List[int] - The separators between input axes when generating flattened output axes. - """ - shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape - if strides is None: - strides = [] - _ffi_api.PreflattenedBuffer( # type: ignore[attr-defined] # pylint: disable=no-member - postflattened, - shape, - dtype, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - ) - - def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: """The block declaration statement. 
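With `T.preflattened_buffer` removed from the IR builder here (and from parser_v1 below), TVMScript that needs both views of an argument declares the flat view as an explicit alias over the parameter's data Var, as the updated Ethos-U tests later in this diff do. A minimal, hypothetical sketch of the replacement idiom:

```python
from tvm.script import tir as T

@T.prim_func
def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"],
         input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
    T.func_attr({"global_symbol": "main"})
    # Flat aliases over the same data Vars replace the former
    # T.preflattened_buffer declarations.
    placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data)
    ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data)
    for i in T.serial(2048):
        ethosu_write[i] = placeholder[i]  # placeholder body, for illustration only
```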
@@ -1697,7 +1629,6 @@ def f(): "func_attr", "func_ret", "match_buffer", - "preflattened_buffer", "block", "init", "where", diff --git a/python/tvm/script/parser_v1/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py index f7f16855c752..b84b7d398084 100644 --- a/python/tvm/script/parser_v1/context_maintainer.py +++ b/python/tvm/script/parser_v1/context_maintainer.py @@ -129,8 +129,6 @@ class ContextMaintainer: """List[Var]: The function parameters""" func_buffer_map: Mapping[Var, Buffer] = {} """Mapping[Var, Buffer]: The function buffer map""" - func_preflattened_buffer_map: Mapping[Var, Buffer] = {} - """Mapping[Var, Buffer]: The function buffer map, prior to any flattening.""" func_dict_attr: Mapping[str, Object] = {} """Mapping[str, Object]: The function attrs""" func_var_env_dict: Mapping[Var, str] = {} @@ -160,7 +158,6 @@ def __init__( # function context self.func_params = [] self.func_buffer_map = {} - self.func_preflattened_buffer_map = {} self.func_dict_attr = {} self.func_var_env_dict = {} # parser and analyzer diff --git a/python/tvm/script/parser_v1/parser.py b/python/tvm/script/parser_v1/parser.py index c34aae23453c..ce8c1fe161a3 100644 --- a/python/tvm/script/parser_v1/parser.py +++ b/python/tvm/script/parser_v1/parser.py @@ -501,7 +501,6 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: body, ret_type, buffer_map=self.context.func_buffer_map, - preflattened_buffer_map=self.context.func_preflattened_buffer_map, attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None, span=tvm_span_from_synr(node.span), ) diff --git a/python/tvm/script/parser_v1/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi index a64eed055ae8..beefaf4c75d7 100644 --- a/python/tvm/script/parser_v1/tir/__init__.pyi +++ b/python/tvm/script/parser_v1/tir/__init__.pyi @@ -117,18 +117,6 @@ def store( ) -> None: ... def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ... def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ... -def preflattened_buffer( - buf: Buffer, - shape: Sequence[PrimExpr], - dtype: str = "float32", - data: Optional[Ptr] = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", -) -> Buffer: ... """ Intrinsics - tvm builtin diff --git a/python/tvm/script/parser_v1/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py index 7cbf47441053..f558eb6b7f73 100644 --- a/python/tvm/script/parser_v1/tir/special_stmt.py +++ b/python/tvm/script/parser_v1/tir/special_stmt.py @@ -904,79 +904,6 @@ def func_attr(dict_attr, span): super().__init__(func_attr, def_symbol=False) -@register -class PreflattenedBufferMap(SpecialStmt): - """Special Stmt for declaring the PrimFunc::preflattened_buffer_map - - Example - ------- - .. code-block:: python - A0 = T.match_buffer(A, (48,), dtype="float32") - T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32") - """ - - def __init__(self): - def preflattened_buffer( - postflattened, - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - span=None, - ): - - param = None - for key, value in self.context.func_buffer_map.items(): - if value.same_as(postflattened): - param = key - break - - assert ( - param is not None - ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map." 
- - if data is None: - data = self.context.func_buffer_map[param].data - - buffer_name: str = f"{postflattened.name}_preflatten" - if align != -1: - if isinstance(align, IntImm): - align = align.value - else: - assert isinstance(align, int), f"align: want int or IntImm, got {align!r}" - - if offset_factor != 0: - if isinstance(offset_factor, IntImm): - offset_factor = offset_factor.value - else: - assert isinstance( - offset_factor, int - ), f"offset_factor: want int or IntImm, got {offset_factor!r}" - - preflattened = tvm.tir.decl_buffer( - shape, - dtype, - buffer_name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - span=span, - ) - - self.context.func_preflattened_buffer_map[param] = preflattened - - super().__init__(preflattened_buffer, def_symbol=False) - - @register class TargetAttrValue(SpecialStmt): """Special Stmt for target attr value. diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 4628ae36265f..c5cc922a3e48 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -49,9 +49,6 @@ class PrimFunc(BaseFunc): buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer] The buffer binding map. - preflattened_buffer_map : Optional[Map[tvm.tir.Var, tvm.tir.Buffer]] - The buffer binding map, prior to any flattening. - attrs: Optional[tvm.Attrs] Attributes of the function, can be None @@ -65,14 +62,12 @@ def __init__( body, ret_type=None, buffer_map=None, - preflattened_buffer_map=None, attrs=None, span=None, ): param_list = [] buffer_map = {} if buffer_map is None else buffer_map - preflattened_buffer_map = {} if preflattened_buffer_map is None else preflattened_buffer_map for x in params: x = tvm.runtime.convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): @@ -90,7 +85,6 @@ def __init__( body, ret_type, buffer_map, - preflattened_buffer_map, attrs, span, ) # type: ignore @@ -116,7 +110,6 @@ def with_body(self, new_body, span=None): new_body, self.ret_type, self.buffer_map, - self.preflattened_buffer_map, self.attrs, span, ) diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index e50559ac10ff..fc3f49d76fae 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -152,16 +152,6 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); } - if (op->preflattened_buffer_map.size() != 0) { - // print preflattened_buffer_map - std::vector preflattened_buffer_map_doc; - for (auto& v : op->preflattened_buffer_map) { - preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second)); - } - doc << Doc::Indent(2, Doc::NewLine() - << "preflattened_buffer_map = {" - << PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}"); - } doc << PrintBody(op->body); return doc; } diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index d7a3a406e352..3182c6a5b61c 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1671,26 +1671,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { body << Print((*it).first) << ", " << memo_buf_decl_[buf]; body << ")" << Doc::NewLine(); } - // print preflattened buffer map - for (const auto& param : op->params) { - auto pf_buf_it = op->preflattened_buffer_map.find(param); - if (pf_buf_it != op->preflattened_buffer_map.end()) { - const Buffer& preflattened = (*pf_buf_it).second; - - auto buf_it = op->buffer_map.find(param); - 
ICHECK(buf_it != op->buffer_map.end()) << "Found pre-flattened buffer " << preflattened->name - << " with no corresponding post-flatten buffer."; - const Buffer& postflattened = (*buf_it).second; - - // Call Print() without assigning in order to fill memo_buf_decl_. - Print(preflattened); - buf_not_in_headers_.insert(preflattened.get()); - ICHECK(memo_buf_decl_.count(preflattened)); - - body << tir_prefix_ << ".preflattened_buffer(" << Print(postflattened) << ", " - << memo_buf_decl_.at(preflattened) << ")" << Doc::NewLine(); - } - } // print body body << "# body" << Doc::NewLine(); diff --git a/src/relay/backend/aot/aot_lower_main.cc b/src/relay/backend/aot/aot_lower_main.cc index 82393c535c43..2a4dfb84ddcf 100644 --- a/src/relay/backend/aot/aot_lower_main.cc +++ b/src/relay/backend/aot/aot_lower_main.cc @@ -504,7 +504,7 @@ class AOTMainLowerer : public MixedModeVisitor { tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, + return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, DictAttrs(dict_attrs)); } diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 786b3f81a5ae..3c0ab7c16f23 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -803,7 +803,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, + return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, DictAttrs(dict_attrs)); } diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index da51e6b762dd..1ea020e884de 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -108,7 +108,7 @@ class RelayToTIRVisitor : public MixedModeMutator { } tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map, - Map(), DictAttrs(dict_attrs)); + DictAttrs(dict_attrs)); ir_module_->Add(global_var, replacement_func); } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index eb6cf1cce420..ad2b06695cc1 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -152,7 +152,7 @@ class ConvertAddToSubtract : public MixedModeMutator { }; tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), - buffer_map, {}, DictAttrs(dict_attrs)); + buffer_map, DictAttrs(dict_attrs)); // Switch to TIRToRuntime hook for testing Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index f48ee52506b4..1e63201a40dd 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -34,7 +34,6 @@ void PrimFuncFrameNode::ExitWithScope() { /*body=*/AsStmt(stmts), /*ret_type=*/ret_type.value_or(TupleType::Empty()), /*buffer_map=*/buffer_map, - /*preflattened_buffer_map=*/preflattened_buffer_map, /*attrs=*/attrs.defined() ? 
DictAttrs(attrs.value()) : NullValue()); func = tvm::tir::ScriptComplete(func, root_alloc_buffers); IRBuilder builder = IRBuilder::Current(); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 78107136d492..822e8e468377 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -58,7 +58,6 @@ PrimFuncFrame PrimFunc() { n->args.clear(); n->ret_type = NullOpt; n->buffer_map.clear(); - n->preflattened_buffer_map.clear(); n->attrs = NullOpt; n->env_threads.clear(); n->root_alloc_buffers.clear(); @@ -137,26 +136,6 @@ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optio return buffer; } -void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, DataType dtype, - Optional data, Array strides, PrimExpr elem_offset, - String storage_scope, int align, int offset_factor, String buffer_type_str, - Array axis_separators) { - PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer"); - for (auto const& p : frame->buffer_map) { - if (p.second.same_as(postflattened_buffer)) { - String buffer_name(postflattened_buffer->name + "_preflatten"); - Buffer buffer = - BufferDecl(shape, dtype, buffer_name, data.value_or(p.second->data), strides, elem_offset, - storage_scope, align, offset_factor, buffer_type_str, axis_separators); - details::Namer::Name(buffer, buffer_name); - frame->preflattened_buffer_map.Set(p.first, buffer); - return; - } - } - LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name - << " does not exist."; -} - BlockFrame Block(String name, bool no_realize) { ObjectPtr n = make_object(); n->name = name; @@ -595,7 +574,6 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init); diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 32b59ce54b69..d0933e0691dd 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -210,8 +210,6 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { // Start with a copy of the current prim_func buffer map. Map new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end()); - Map new_preflattened_buffer_map(prim_func->preflattened_buffer_map.begin(), - prim_func->preflattened_buffer_map.end()); bool any_change = false; // For each constrained parameter... @@ -225,23 +223,6 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { any_change = true; } new_buffer_map.Set(param, new_buffer); - - // Rewrite the pre-flattened buffers to account for constraint. - // This only has an impact if the IRModule being analyzed has - // already been run through the StorageFlatten or FlattenBuffer - // passes. 
- if (auto opt = prim_func->preflattened_buffer_map.Get(param)) { - Buffer pf_buffer = opt.value(); - if (pf_buffer.same_as(buffer)) { - new_preflattened_buffer_map.Set(param, new_buffer); - } else { - const Buffer new_buffer = RewriteBuffer(pf_buffer, virtual_device); - if (!new_buffer.same_as(pf_buffer)) { - any_change = true; - } - new_preflattened_buffer_map.Set(param, new_buffer); - } - } } // Make sure we have accounted for all prim_func parameters. CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); @@ -259,8 +240,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { if (any_change) { return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type, - std::move(new_buffer_map), std::move(new_preflattened_buffer_map), - prim_func->attrs, prim_func->span); + std::move(new_buffer_map), prim_func->attrs, prim_func->span); } else { return prim_func; } diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index d51ffbf833a4..369c4adc8536 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -152,9 +152,8 @@ class HoistAllocatesMutator : public StmtExprMutator { current_alloc->span); } - PrimFunc new_main_func = - PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, main_func->buffer_map, - main_func->preflattened_buffer_map, main_func->attrs); + PrimFunc new_main_func = PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, + main_func->buffer_map, main_func->attrs); return new_main_func; } @@ -523,7 +522,6 @@ class MergeConstantsMutator : public StmtExprMutator { prim_func_node->body = std::move(new_body); prim_func_node->buffer_map = std::move(new_buffer_map); prim_func_node->params = std::move(new_params); - prim_func_node->preflattened_buffer_map = {}; PrimFunc f{GetRef(prim_func_node)}; // Add the new const dict as an attribute diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index c609ad158e34..d4802e287693 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -29,9 +29,7 @@ namespace tvm { namespace tir { // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, - Map buffer_map, - Optional> preflattened_buffer_map, DictAttrs attrs, - Span span) { + Map buffer_map, DictAttrs attrs, Span span) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. 
if (!ret_type.defined()) { @@ -42,7 +40,6 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, n->body = std::move(body); n->ret_type = std::move(ret_type); n->buffer_map = std::move(buffer_map); - n->preflattened_buffer_map = preflattened_buffer_map.value_or(Map()); n->attrs = std::move(attrs); n->checked_type_ = n->func_type_annotation(); n->span = std::move(span); @@ -129,9 +126,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("tir.PrimFunc") .set_body_typed([](Array params, Stmt body, Type ret_type, - Map buffer_map, - Map preflattened_buffer_map, DictAttrs attrs, Span span) { - return PrimFunc(params, body, ret_type, buffer_map, preflattened_buffer_map, attrs, span); + Map buffer_map, DictAttrs attrs, Span span) { + return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); TVM_REGISTER_GLOBAL("tir.TensorIntrin") diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 5dc08f31c23c..040c48c79693 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -308,37 +308,8 @@ class BF16LowerRewriter : public StmtExprMutator { } } - // Most passes do not change the preflattened buffer map, nor - // should they change it. This is an exception, because the Var - // associated with the `BufferNode::data` in - // `PrimFunc::buffer_map` may be replaced, and the corresponding - // Var in the `PrimFunc::preflattened_buffer_map` must also be - // replaced. - Map new_preflattened_buffer_map; - for (auto& itr : op->preflattened_buffer_map) { - auto param_var = itr.first; - auto oldbuf = itr.second; - if (oldbuf->dtype.is_bfloat16()) { - auto it = new_buffer_map.find(param_var); - ICHECK(it != new_buffer_map.end()) - << "PrimFunc parameter " << param_var->name_hint - << " is associated with the pre-flattened buffer " << oldbuf->name - << ", but isn't associated with any post-flatten buffer."; - const Buffer& flatbuf = (*it).second; - DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); - auto newbuf = Buffer(flatbuf->data, dtype, oldbuf->shape, oldbuf->strides, - oldbuf->elem_offset, oldbuf->name, oldbuf->data_alignment, - oldbuf->offset_factor, oldbuf->buffer_type); - buffer_remap_[oldbuf] = newbuf; - new_preflattened_buffer_map.Set(param_var, newbuf); - } else { - new_preflattened_buffer_map.Set(param_var, oldbuf); - } - } - if (buffer_remap_.size() != 0) { op->buffer_map = new_buffer_map; - op->preflattened_buffer_map = new_preflattened_buffer_map; } } diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 5441120491c6..d51a44887f54 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -37,22 +37,18 @@ namespace tir { class BufferFlattener : public StmtExprMutator { public: static PrimFunc Flatten(PrimFunc func) { - Map preflattened_buffer_map = - Merge(func->buffer_map, func->preflattened_buffer_map); - auto pass = BufferFlattener(func->buffer_map); + auto pass = BufferFlattener(); auto writer = func.CopyOnWrite(); writer->body = pass.VisitStmt(func->body); - writer->preflattened_buffer_map = preflattened_buffer_map; - writer->buffer_map = pass.updated_extern_buffer_map_; + // The buffers in func->buffer_map are deliberately left + // unflattened, as they are used for validation of user-provided + // arguments. The flattened buffers used in the updated + // function body alias the argument buffers. 
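Both flattening paths behave the same way here. As a quick sketch (not part of the patch): lowering a small TE schedule, which runs StorageFlatten, now leaves the parameter buffers in `buffer_map` at their original rank, while the lowered body indexes flattened aliases of the same data Vars.

```python
import tvm
from tvm import te

A = te.placeholder((4, 4), name="A")
B = te.compute((4, 4), lambda i, j: A[i, j] + 1.0, name="B")
mod = tvm.lower(te.create_schedule(B.op), [A, B])

prim = mod["main"]
for param in prim.params:
    # Parameter buffers stay rank-2 for argument validation; only the body
    # is rewritten to use the flat aliases.
    assert len(prim.buffer_map[param].shape) == 2
```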
return func; } private: - explicit BufferFlattener(const Map& extern_buffer_map) { - for (const auto& kv : extern_buffer_map) { - updated_extern_buffer_map_.Set(kv.first, GetFlattenedBuffer(kv.second)); - } - } + BufferFlattener() {} Stmt VisitStmt_(const BlockNode* op) final { ICHECK_EQ(op->match_buffers.size(), 0) diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index 344e6c7ae3cb..fed76876f6bf 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -74,9 +74,9 @@ class PackedCallLegalizer : public StmtExprMutator { tvm::runtime::Map::iterator param_buf_it; if (prim_func != nullptr) { auto param_var = prim_func->params[i - 1]; - param_buf_it = prim_func->preflattened_buffer_map.find(param_var); + param_buf_it = prim_func->buffer_map.find(param_var); } - if (prim_func != nullptr && param_buf_it != prim_func->preflattened_buffer_map.end()) { + if (prim_func != nullptr && param_buf_it != prim_func->buffer_map.end()) { Buffer param = (*param_buf_it).second; PrimExpr shape = tvm::tir::Call( DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 5b9bac03aba9..c1611a23a05f 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -209,9 +209,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) { continue; } - if (func_ptr->preflattened_buffer_map.count(param)) { - buffer_def.emplace_back(v_arg, func_ptr->preflattened_buffer_map[param]); - } else if (func_ptr->buffer_map.count(param)) { + if (func_ptr->buffer_map.count(param)) { buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]); } else { var_def.emplace_back(v_arg, param); diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index db59824bf1ce..90150ebd3cdf 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -52,21 +52,21 @@ class BufferAllocationLocator : public StmtExprMutator { public: explicit BufferAllocationLocator(const PrimFunc& func) { Map> buffer_lca = DetectBufferAccessLCA(func); + std::unordered_set arg_buffer_vars; CollectUnmanagedAllocations collector; collector(func->body); unmanaged_allocations_ = collector.unmanaged_allocations; - std::unordered_set arg_buffers; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; - arg_buffers.emplace(buffer.get()); + arg_buffer_vars.emplace(buffer->data.get()); buffer_data_to_buffer_.Set(buffer->data, buffer); } // create buffers to be allocated at each stmts for (const auto& kv : buffer_lca) { const Buffer& buffer = kv.first; const StmtNode* stmt = kv.second.get(); - if (arg_buffers.count(buffer.get())) { + if (arg_buffer_vars.count(buffer->data.get())) { continue; } if (!unmanaged_allocations_.count(buffer->data.get())) { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index ab1b062ad647..eb0409e555a1 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -402,6 +402,7 @@ class BufferStrideLegalize : public StmtExprMutator { auto fptr = func.CopyOnWrite(); fptr->body = pass(std::move(fptr->body)); + fptr->buffer_map = pass.UpdatedExternBufferMap(); if (auto map = func->attrs.GetAttr>>("layout_transform_map")) { func = WithAttr(std::move(func), 
"layout_transform_map", pass.UpdateIndexMap(map.value())); } @@ -420,7 +421,6 @@ class BufferStrideLegalize : public StmtExprMutator { BufferEntry entry; entry.remap_to = with_strides; entry.in_scope = true; - entry.is_external = true; buf_map_[buf] = entry; } updated_extern_buffer_map_.Set(kv.first, with_strides); @@ -443,51 +443,54 @@ class BufferStrideLegalize : public StmtExprMutator { Map UpdatedExternBufferMap() const { return updated_extern_buffer_map_; } Buffer WithStrides(Buffer buf) { - auto it = buf_map_.find(buf); + auto cache_key = buf; + + auto it = buf_map_.find(cache_key); if (it != buf_map_.end()) { const BufferEntry& entry = it->second; ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer"; return entry.remap_to; } + Array shape = buf->shape; + if (buf->strides.size()) { ICHECK_EQ(buf->strides.size(), buf->shape.size()) << "Buffer " << buf << " has inconsistent strides/shape."; - return buf; - } - - // Keeping this to have matched behavior to previous version. - // There are many parts of the codebase that assume that a strided - // array cannot be compact. For example, ArgBinder::BindBuffer - // and tir.Specialize. - if (dim_align_.count(buf) == 0) { - return buf; - } - - // Can't define the strides for a buffer without a known shape. - Array shape = buf->shape; - if (shape.size() == 0) { - return buf; - } - - std::vector rstrides; - const std::vector& avec = dim_align_[buf]; - int first_dim = 0; - PrimExpr stride = make_const(shape[first_dim].dtype(), 1); - for (size_t i = shape.size(); i != 0; --i) { - size_t dim = i - 1; - if (dim < avec.size() && avec[dim].align_factor != 0) { - PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); - PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = bound_analyzer_->Simplify(stride); + } else if (dim_align_.count(buf) == 0) { + // Keeping this to have matched behavior to previous version. + // There are many parts of the codebase that assume that a + // strided array cannot be compact. For example, + // ArgBinder::BindBuffer and tir.Specialize. To avoid breaking + // these, do not define the strides unless required for a + // non-compact array. + } else if (shape.size() == 0) { + // Can't define the strides for a buffer without a known shape. 
+ } else { + // With everything checked, can now define the updated strides + std::vector rstrides; + const std::vector& avec = dim_align_[buf]; + int first_dim = 0; + PrimExpr stride = make_const(shape[first_dim].dtype(), 1); + for (size_t i = shape.size(); i != 0; --i) { + size_t dim = i - 1; + if (dim < avec.size() && avec[dim].align_factor != 0) { + PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); + PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); + stride = bound_analyzer_->Simplify(stride); + } + rstrides.push_back(stride); + stride = stride * shape[dim]; } - rstrides.push_back(stride); - stride = stride * shape[dim]; + + buf.CopyOnWrite()->strides = Array(rstrides.rbegin(), rstrides.rend()); } - auto ptr = buf.CopyOnWrite(); - ptr->strides = Array(rstrides.rbegin(), rstrides.rend()); + BufferEntry entry; + entry.remap_to = buf; + entry.in_scope = true; + buf_map_[cache_key] = entry; return buf; } @@ -513,16 +516,10 @@ class BufferStrideLegalize : public StmtExprMutator { Buffer target_with_strides = WithStrides(Downcast(arr[1])); Buffer source_with_strides = WithStrides(source); - { - BufferEntry entry; - entry.remap_to = source_with_strides; - entry.in_scope = true; - entry.is_external = false; - buf_map_[source] = entry; - } - Stmt body = this->VisitStmt(op->body); + buf_map_[source].in_scope = false; + return AttrStmt(Array{source_with_strides, target_with_strides}, op->attr_key, op->value, body, op->span); } else { @@ -560,13 +557,6 @@ class BufferStrideLegalize : public StmtExprMutator { Stmt VisitStmt_(const BufferRealizeNode* op) final { Buffer key = op->buffer; Buffer with_strides = WithStrides(op->buffer); - { - BufferEntry entry; - entry.remap_to = with_strides; - entry.in_scope = true; - entry.is_external = false; - buf_map_[key] = entry; - } Stmt stmt = StmtExprMutator::VisitStmt_(op); @@ -589,22 +579,14 @@ class BufferStrideLegalize : public StmtExprMutator { template Node VisitBufferAccess(Node node) { - auto alloc_key = node->buffer->data.get(); - if (!buf_map_.count(node->buffer) && buffer_var_defines_.count(alloc_key)) { - BufferEntry entry; - entry.remap_to = WithStrides(node->buffer); - entry.in_scope = true; - entry.is_external = false; - buf_map_[node->buffer] = entry; - } - auto it = buf_map_.find(node->buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << node->buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot access a buffer " << node->buffer->name << ", out of scope"; + ICHECK(it == buf_map_.end() || it->second.in_scope) + << "Cannot access a buffer " << node->buffer->name << ", out of scope"; - auto writer = node.CopyOnWrite(); - writer->buffer = e.remap_to; + auto with_strides = WithStrides(node->buffer); + if (!with_strides.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = with_strides; + } return node; } @@ -623,7 +605,6 @@ class BufferStrideLegalize : public StmtExprMutator { struct BufferEntry { Buffer remap_to; bool in_scope; - bool is_external; }; std::unordered_map buf_map_; @@ -846,6 +827,7 @@ class BufferBindUnwrapper : public StmtExprMutator { BufferEntry e; e.buffer = kv.second; e.external = true; + var_to_buffer_[kv.second->data.get()] = kv.second; buf_map_[kv.second.get()] = std::move(e); } } @@ -1001,6 +983,7 @@ class BufferBindUnwrapper : public StmtExprMutator { BufferEntry e; e.bounds = op->bounds; e.buffer = op->buffer; + 
var_to_buffer_[op->buffer->data.get()] = op->buffer; buf_map_[key] = std::move(e); } @@ -1089,6 +1072,7 @@ class BufferBindUnwrapper : public StmtExprMutator { source_info.buffer = source; source_info.remap = std::make_unique(remap); + var_to_buffer_[source->data.get()] = source; buf_map_[source.get()] = std::move(source_info); } @@ -1160,18 +1144,70 @@ class BufferBindUnwrapper : public StmtExprMutator { }; const BufferEntry& GetBufferEntry(Buffer buffer) { - auto alloc_key = buffer->data.get(); - if (!buf_map_.count(buffer.get()) && buffer_var_defines_.count(alloc_key)) { + if (buf_map_.count(buffer.get())) { + const BufferEntry& e = buf_map_[buffer.get()]; + ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; + return e; + } else if (buffer_var_defines_.count(buffer->data.get())) { + // The buffer var was defined, but the buffer hasn't been seen + // before. BufferEntry entry; entry.buffer = buffer; + var_to_buffer_[buffer->data.get()] = buffer; buf_map_[buffer.get()] = std::move(entry); - } + return buf_map_[buffer.get()]; + } else if (var_remap_.count(buffer->data.get())) { + // The buffer var is an alias of a bound buffer. Only + // supported if the bound buffer has no offsets. In this + // case, we just need to make a new aliasing buffer that + // shares the remapped data variable. + Var old_var = buffer->data; + Var new_var = Downcast(var_remap_[old_var.get()]); - auto it = buf_map_.find(buffer.get()); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; - return it->second; + { + ICHECK(var_to_buffer_.count(old_var.get())) + << "Cannot find remap information for aliased buffer var " << old_var->name_hint + << ", required to verify this alias is legal."; + const Buffer& aliased_buffer = var_to_buffer_[old_var.get()]; + const BufferEntry& entry = buf_map_[aliased_buffer.get()]; + if (entry.remap) { + for (const auto& begin : entry.remap->begins) { + ICHECK(is_zero(begin)) << "Aliasing of buffer with offset is not supported"; + } + } + } + + { + Buffer new_buf = buffer; + new_buf.CopyOnWrite()->data = new_var; + + RemapInfo remap_info; + remap_info.target = new_buf; + remap_info.begins = Array(buffer->shape.size(), 0); + remap_info.extents = buffer->shape; + + BufferEntry entry; + entry.buffer = buffer; + entry.remap = std::make_unique(remap_info); + entry.in_scope = true; + var_to_buffer_[buffer->data.get()] = buffer; + buf_map_[buffer.get()] = std::move(entry); + } + return buf_map_[buffer.get()]; + } else if (var_to_buffer_.count(buffer->data.get())) { + // This buffer is an alias of a known buffer, with no remaps. A + // buffer entry should be generated and returned. + BufferEntry entry; + entry.buffer = buffer; + entry.in_scope = true; + var_to_buffer_[buffer->data.get()] = buffer; + buf_map_[buffer.get()] = std::move(entry); + + return buf_map_[buffer.get()]; + } else { + LOG(FATAL) << "Can't work around the undefined buffer"; + return *static_cast(nullptr); + } } // The buffer assignment map @@ -1181,6 +1217,9 @@ class BufferBindUnwrapper : public StmtExprMutator { std::unordered_set illegal_vars_; // Buffer map std::unordered_map buf_map_; + // Map from Var to the Buffer they occurred in. In case of aliased + // buffers, contains the first buffer. + std::unordered_map var_to_buffer_; // Set of vars that have occurred in an AllocateNode, but haven't // yet occurred in a BufferLoad/BufferStore. 
std::unordered_set buffer_var_defines_; @@ -1311,13 +1350,12 @@ class StorageFlattener : public StmtExprMutator { auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer); - Map preflattened_buffer_map = - Merge(func->buffer_map, func->preflattened_buffer_map); - auto fptr = func.CopyOnWrite(); fptr->body = pass(std::move(fptr->body)); - fptr->preflattened_buffer_map = preflattened_buffer_map; - fptr->buffer_map = pass.UpdatedBufferMap(); + // The buffers in func->buffer_map are deliberately left + // unflattened, as they are used for validation of user-provided + // arguments. The flattened buffers used in the updated + // function body alias the argument buffers. return func; }; return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); @@ -1345,15 +1383,12 @@ class StorageFlattener : public StmtExprMutator { } } e.external = true; + buffer_var_defines_.insert(kv.second->data.get()); buf_map_[kv.second] = e; - - updated_extern_buffer_map_.Set(kv.first, e.flattened_buffer); } cache_line_size_ = cache_line_size; } - Map UpdatedBufferMap() { return updated_extern_buffer_map_; } - Stmt VisitStmt_(const StoreNode* op) final { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; return Stmt(); @@ -1512,8 +1547,10 @@ class StorageFlattener : public StmtExprMutator { writer->dtype = DataType::Int(8); } + buffer_var_defines_.insert(op->buffer->data.get()); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); + buffer_var_defines_.erase(op->buffer->data.get()); buf_map_[key].in_scope = false; Stmt ret = @@ -1777,8 +1814,6 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map> buffer_var_map_; // Buffer map std::unordered_map buf_map_; - // The extern buffer map, updated to include flattened buffers. - Map updated_extern_buffer_map_; // Collects shapes. std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. 
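The C++ call sites in the files below simply drop the removed argument. For reference, the matching Python-side constructor after this change takes (params, body, ret_type, buffer_map, attrs, span); a minimal sketch with hypothetical names:

```python
import tvm
from tvm import tir

a = tir.Var("a", "handle")
a_buf = tir.decl_buffer((16,), "float32", "A")

# No preflattened_buffer_map argument any more.
func = tir.PrimFunc([a], tir.Evaluate(tir.IntImm("int32", 0)), buffer_map={a: a_buf})
```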
diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index 0671f1ea2722..2bded7b4877b 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -166,8 +166,8 @@ IRModule PoolInfoAssigner::operator()() { if (kv.second->IsInstance()) { func_ = Downcast(kv.second); Stmt body = this->VisitStmt(func_->body); - PrimFunc new_prim_func = PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, - func_->preflattened_buffer_map, func_->attrs); + PrimFunc new_prim_func = + PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, func_->attrs); mod_->Update(gv, new_prim_func); } } diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 56aba654b59e..439e2643380a 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -242,8 +242,8 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( if (emit_tvmscript_printable_) { original_attrs = DictAttrs(); } - PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, - si.buffer_map, original_attrs); + PrimFunc ret = + PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs); if (!emit_tvmscript_printable_) { ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); } @@ -449,12 +449,12 @@ IRModule PoolAllocationToOffsetConverter::operator()() { // We dont need attrs of PrimFunc that might include non printable attrs such as target // for unit tests where emit_tvmscript_printable_ is to be used. if (!emit_tvmscript_printable_) { - main_func = PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {}, - main_func->attrs); + main_func = + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs); main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params); } else { main_func = - PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {}, DictAttrs()); + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs()); } module_->Update(gv, main_func); if (!emit_tvmscript_printable_) { diff --git a/src/tir/usmp/transform/create_io_allocates.cc b/src/tir/usmp/transform/create_io_allocates.cc index 59eee961632d..cf754131776c 100644 --- a/src/tir/usmp/transform/create_io_allocates.cc +++ b/src/tir/usmp/transform/create_io_allocates.cc @@ -195,9 +195,8 @@ IRModule IOAllocateCreator::operator()() { } } const GlobalVar& gv = mod_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main); - mod_->Update(gv, - PrimFunc(new_main_params, main_body, main_func_->ret_type, main_func_->buffer_map, - main_func_->preflattened_buffer_map, main_func_->attrs, main_func_->span)); + mod_->Update(gv, PrimFunc(new_main_params, main_body, main_func_->ret_type, + main_func_->buffer_map, main_func_->attrs, main_func_->span)); return mod_; } diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index c751d44b6156..61128da71c37 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -34,9 +34,11 @@ @tvm.script.ir_module class WeightStreamOnlyU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> 
None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data) buffer1 = T.buffer_decl([160], "uint8") buffer3 = T.buffer_decl([144], "uint8") buffer5 = T.buffer_decl([144], "uint8") @@ -62,10 +64,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), @tvm.script.ir_module class WeightStreamOnlyU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) buffer_encoded_1 = T.buffer_decl([192], dtype="uint8") buffer_encoded_2_1 = T.buffer_decl([192], dtype="uint8") buffer_encoded_4_1 = T.buffer_decl([208], dtype="uint8") @@ -148,10 +152,12 @@ def _get_func(): @tvm.script.ir_module class RereadWeightsU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([384], "uint8") + placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data) # body p1_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) p1 = T.buffer_decl([384], "uint8", data=p1_data) @@ -167,10 +173,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), @tvm.script.ir_module class RereadWeightsU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) placeholder_encoded_1 = T.buffer_decl([464], "uint8") # body p1_data = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) @@ -246,13 +254,15 @@ def _get_func(): @tvm.script.ir_module class DirectReadOnlyU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([592], "uint8") buffer_1 = T.buffer_decl([160], "uint8") buffer_2 = T.buffer_decl([160], 
"uint8") buffer_3 = T.buffer_decl([80], "uint8") + placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data) # body ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) ethosu_write_1 = T.buffer_decl([4096], "int8", data=ethosu_write_1_data) @@ -264,7 +274,7 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), @tvm.script.ir_module class DirectReadOnlyU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition @@ -272,6 +282,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8") placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8") placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8") + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body ethosu_write_2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) ethosu_write_2 = T.buffer_decl([4096], "int8", data=ethosu_write_2_data) @@ -340,7 +352,7 @@ def _get_func(): @tvm.script.ir_module class MixedReadU55: @T.prim_func - def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([112], "uint8") @@ -349,6 +361,8 @@ def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"] buffer7 = T.buffer_decl([112], "uint8") buffer9 = T.buffer_decl([592], "uint8") buffer10 = T.buffer_decl([160], "uint8") + ifm = T.buffer_decl([8192], "int8", data=input_ifm.data) + ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data) # body p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) p1 = T.buffer_decl([112], "uint8", data=p1_data) @@ -371,11 +385,12 @@ def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"] @tvm.script.ir_module class MixedReadU65: @T.prim_func - def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - # buffer definition + ifm = T.buffer_decl([8192], dtype="int8", data=input_ifm.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) buffer1 = T.buffer_decl([128], dtype="uint8") buffer2 = T.buffer_decl([128], dtype="uint8") buffer3 = T.buffer_decl([128], dtype="uint8") diff --git a/tests/python/contrib/test_ethosu/test_hoist_allocates.py b/tests/python/contrib/test_ethosu/test_hoist_allocates.py index 6c6d51fa06b9..1508aa441c3b 100644 --- a/tests/python/contrib/test_ethosu/test_hoist_allocates.py +++ 
b/tests/python/contrib/test_ethosu/test_hoist_allocates.py @@ -106,15 +106,15 @@ def test_double_convolution(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(3402,), "int8"], placeholder_encoded: T.Buffer[(128,), "uint8"], placeholder_encoded_1: T.Buffer[(32,), "uint8"], placeholder_encoded_2: T.Buffer[(128,), "uint8"], placeholder_encoded_3: T.Buffer[(32,), "uint8"], ethosu_write: T.Buffer[(3402,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 27, 42, 3), "int8"], input_placeholder_encoded: T.Buffer[(3, 3, 2, 3), "uint8"], input_placeholder_encoded_1: T.Buffer[(3, 10), "uint8"], input_placeholder_encoded_2: T.Buffer[(3, 3, 2, 3), "uint8"], input_placeholder_encoded_3: T.Buffer[(3, 10), "uint8"], input_ethosu_write: T.Buffer[(1, 27, 42, 3), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 27, 42, 3], dtype="int8", data=placeholder.data) - T.preflattened_buffer(placeholder_encoded, [3, 3, 2, 3], dtype="int8") - T.preflattened_buffer(placeholder_encoded_1, [3, 10], dtype="uint8") - T.preflattened_buffer(placeholder_encoded_2, [3, 3, 2, 3], dtype="int8") - T.preflattened_buffer(placeholder_encoded_3, [3, 10], dtype="uint8") - T.preflattened_buffer(ethosu_write, [1, 27, 42, 3], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([3402], dtype="int8", data=input_placeholder.data) + placeholder_encoded = T.buffer_decl([128], dtype="int8", data=input_placeholder_encoded.data) + placeholder_encoded_1 = T.buffer_decl([32], dtype="uint8", data=input_placeholder_encoded_1.data) + placeholder_encoded_2 = T.buffer_decl([128], dtype="int8", data=input_placeholder_encoded_2.data) + placeholder_encoded_3 = T.buffer_decl([32], dtype="uint8", data=input_placeholder_encoded_3.data) + ethosu_write = T.buffer_decl([3402], dtype="int8", data=input_ethosu_write.data) # body placeholder_global_data = T.allocate([128], "uint8", "global") placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data) @@ -150,11 +150,10 @@ def test_identities(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(24,), "int8"], T_concat: T.Buffer[(24,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 2, 3, 4), "int8"], T_concat: T.Buffer[(24,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 2, 3, 4], dtype="int8", data=placeholder.data) - T.preflattened_buffer(T_concat, [24], dtype="int8", data=T_concat.data) + placeholder = T.buffer_decl([24], dtype="int8", data=input_placeholder.data) # body ethosu_write_data = T.allocate([12], "int8", "global") ethosu_write = T.buffer_decl([12], "int8", data=ethosu_write_data) @@ -188,11 +187,11 @@ def test_outer_seq_stmt(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 
8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body with T.allocate([128], "uint8", "global") as placeholder_global_data: placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data) @@ -238,11 +237,11 @@ def test_allocate_without_seq_stmt(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body placeholder_global_data = T.allocate([128], "uint8", "global") placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data) diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index a5adcfceac83..ed1927b849d6 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -399,12 +399,12 @@ def test_read_from_the_same_buffer(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # 
buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([368], "uint8", "global") p1 = T.buffer_decl([368], "uint8", data=p1_data) @@ -419,9 +419,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([464], "uint8", "global") p1 = T.buffer_decl([464], "uint8", data=p1_data) @@ -446,12 +449,12 @@ def test_arbitrary_argument_order(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([368], "uint8", "global") p1 = T.buffer_decl([368], "uint8", data=p1_data) @@ -473,9 +476,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([464], "uint8", "global") p1 = T.buffer_decl([464], "uint8", data=p1_data) @@ -509,12 +515,12 @@ def test_arbitrary_argument_order_const_split(): @tvm.script.ir_module class 
InputModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([368], "uint8", "global") p1 = T.buffer_decl([368], "uint8", data=p1_data) @@ -536,9 +542,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([464], "uint8", "global") p1 = T.buffer_decl([464], "uint8", data=p1_data) @@ -572,12 +581,12 @@ def test_arbitrary_argument_order_const_split_mixed(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], input_ethosu_write: T.Buffer[(2,16,16,8), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl(8192, dtype='int8', data=input_placeholder.data) + ethosu_write = T.buffer_decl(4096, dtype='int8', data=input_ethosu_write.data) # body p1_data = T.allocate([368], "uint8", "global") p1 = T.buffer_decl([368], "uint8", data=p1_data) @@ -599,9 +608,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], ethosu_write: 
T.Buffer[(4096,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(2,16,16,8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder = T.buffer_decl(8192, dtype='int8', data=input_placeholder.data) + ethosu_write = T.buffer_decl(4096, dtype='int8', data=input_ethosu_write.data) # body p1_data = T.allocate([464], "uint8", "global") p1 = T.buffer_decl([464], "uint8", data=p1_data) diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index e6414c24d4a3..379a35b1b4a4 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -30,9 +30,14 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,), "int8"], T_concat: T.Buffer[(4096,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1,8,12,16), "int8"], input_placeholder_1: T.Buffer[(1,8,10,16), "int8"], input_T_concat: T.Buffer[(1,8,32,16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + + placeholder = T.buffer_decl(1536, dtype="int8", data=input_placeholder.data) + placeholder_1 = T.buffer_decl(1280, dtype="int8", data=input_placeholder_1.data) + T_concat = T.buffer_decl(4096, dtype="int8", data=input_T_concat.data) + buffer = T.buffer_decl([2992], "uint8") buffer_1 = T.buffer_decl([160], "uint8") buffer_2 = T.buffer_decl([2992], "uint8") diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index ae46057369e0..46c6976567c8 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -366,13 +366,15 @@ def _visit(stmt): @tvm.script.ir_module class Conv2dDoubleCascade1: @T.prim_func - def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None: + def main(input_placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([304], "uint8") buffer_1 = T.buffer_decl([80], "uint8") buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([160], "uint8") + placeholder_5 = T.buffer_decl([192], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.buffer_decl([512], 'int8', data=input_ethosu_write_1.data) # body ethosu_write_2_data = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) ethosu_write_2 = T.buffer_decl([1024], "int8", data=ethosu_write_2_data) @@ -386,13 +388,15 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, @tvm.script.ir_module class Conv2dDoubleCascade2: @T.prim_func - def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None: + def main(input_placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) 
buffer = T.buffer_decl([80], "uint8") buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([1312], "uint8") buffer_3 = T.buffer_decl([2608], "uint8") + placeholder_5 = T.buffer_decl([192], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.buffer_decl([512], 'int8', data=input_ethosu_write_1.data) # body ethosu_write_2_data = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) ethosu_write_2 = T.buffer_decl([1536], "int8", data=ethosu_write_2_data) @@ -406,13 +410,16 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, @tvm.script.ir_module class Conv2dDoubleCascade3: @T.prim_func - def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640,), "int8"]) -> None: + def main(input_placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], input_ethosu_write_1: T.Buffer[(1, 20, 4, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([1744], "uint8") buffer_1 = T.buffer_decl([80], "uint8") buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([880], "uint8") + placeholder_5 = T.buffer_decl([768], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.buffer_decl([640], 'int8', data=input_ethosu_write_1.data) + # body ethosu_write_2_data = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) ethosu_write_2 = T.buffer_decl([2560], "int8", data=ethosu_write_2_data) @@ -428,13 +435,15 @@ def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640, @tvm.script.ir_module class Conv2dDoubleCascade4: @T.prim_func - def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 2, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([1456], "uint8") buffer_1 = T.buffer_decl([352], "uint8") buffer_2 = T.buffer_decl([272], "uint8") buffer_3 = T.buffer_decl([11040], "uint8") + placeholder_5 = T.buffer_decl([1024], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.buffer_decl([2048], 'int8', data=input_ethosu_write_1.data) # body ethosu_write_2_data = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) ethosu_write_2 = T.buffer_decl((2304,), "int8", data=ethosu_write_2_data) @@ -448,13 +457,15 @@ def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(204 @tvm.script.ir_module class Conv2dDoubleCascade5: @T.prim_func - def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write: T.Buffer[(1, 32, 32, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([304], "uint8") buffer_3 = T.buffer_decl([80], "uint8") + placeholder = T.buffer_decl([192], 'int8', data=input_placeholder.data) + ethosu_write = T.buffer_decl([8192], 'int8', data=input_ethosu_write.data) # body ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) ethosu_write_1 = T.buffer_decl([4096], "int8", data=ethosu_write_1_data) 
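The reference modules in this patch all follow the same pattern: the function parameter keeps its unflattened (e.g. NHWC or NHCWB16) shape, and a flat alias over the same allocation is declared with T.buffer_decl(..., data=param.data) so the body can keep addressing the tensor with 1-D offsets. A minimal sketch of that pattern, with assumed function and parameter names and shapes, is:

    from tvm.script import tir as T

    @T.prim_func
    def copy(input_a: T.Buffer[(1, 4, 4, 8), "int8"], input_b: T.Buffer[(1, 4, 4, 8), "int8"]) -> None:
        # Flat views over the parameter allocations (1 * 4 * 4 * 8 = 128 elements).
        a = T.buffer_decl([128], "int8", data=input_a.data)
        b = T.buffer_decl([128], "int8", data=input_b.data)
        for i in T.serial(128):
            b[i] = a[i]

The same idiom appears in every Ethos-U and TIR-transform test module updated below.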
@@ -468,13 +479,15 @@ def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), @tvm.script.ir_module class Conv2dDoubleCascade6: @T.prim_func - def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 8, 1, 8, 16), "int8"], input_ethosu_write: T.Buffer[(1, 32, 2, 32, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([1456], "uint8") buffer_1 = T.buffer_decl([352], "uint8") buffer_2 = T.buffer_decl([11040], "uint8") buffer_3 = T.buffer_decl([272], "uint8") + placeholder = T.buffer_decl([1024], 'int8', data=input_placeholder.data) + ethosu_write = T.buffer_decl([32768], 'int8', data=input_ethosu_write.data) # body ethosu_write_1_data = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) ethosu_write_1 = T.buffer_decl([12288], "int8", data=ethosu_write_1_data) @@ -630,11 +643,13 @@ def _get_func( @tvm.script.ir_module class Conv2dInlineCopy1: @T.prim_func - def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024,), "int8"]) -> None: + def main(input_placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([848], "uint8") buffer_1 = T.buffer_decl([160], "uint8") + placeholder_3 = T.buffer_decl([960], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.buffer_decl([1024], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -643,11 +658,13 @@ def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024 @tvm.script.ir_module class Conv2dInlineCopy2: @T.prim_func - def main(placeholder_3: T.Buffer[(315,), "int8"], ethosu_write_1: T.Buffer[(240,), "int8"]) -> None: + def main(input_placeholder_3: T.Buffer[(1, 7, 9, 5), "int8"], input_ethosu_write_1: T.Buffer[(1, 3, 5, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([656], "uint8") + placeholder_3 = T.buffer_decl([315], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.buffer_decl([240], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -685,11 +702,13 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): @tvm.script.ir_module class Conv2dInlineReshape1: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: + def 
main(input_placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") + placeholder_3 = T.buffer_decl([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.buffer_decl([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -699,11 +718,13 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, @tvm.script.ir_module class Conv2dInlineReshape2: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: + def main(input_placeholder_3: T.Buffer[(1, 24, 8), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") + placeholder_3 = T.buffer_decl([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.buffer_decl([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -713,11 +734,13 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, @tvm.script.ir_module class Conv2dInlineReshape3: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: + def main(input_placeholder_3: T.Buffer[(192, 1), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") + placeholder_3 = T.buffer_decl([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.buffer_decl([768], 'int8', 
data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -727,11 +750,12 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, @tvm.script.ir_module class Conv2dInlineReshape4: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") + ethosu_write_1 = T.buffer_decl([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 8c7ff35272ef..7da3d7e5be82 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -31,10 +31,12 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer_1 = T.buffer_decl([384], "uint8") + placeholder_3 = T.buffer_decl([8192], dtype="int8", data=input_placeholder_3.data) + ethosu_write_1 = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write_1.data) # body placeholder_global_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_global = T.buffer_decl([384], "uint8", data=placeholder_global_data) @@ -73,11 +75,13 @@ def _get_func(): @tvm.script.ir_module class WeightStream: @T.prim_func - def main(placeholder_5: 
T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None: + def main(input_placeholder_5: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write_1: T.Buffer[(1, 16, 16, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([528], "uint8") buffer_2 = T.buffer_decl([336], "uint8") + placeholder_5 = T.buffer_decl([8192], dtype="int8", data=input_placeholder_5.data) + ethosu_write_1 = T.buffer_decl([4096], dtype="int8", data=input_ethosu_write_1.data) # body placeholder_d_global_data = T.allocate([528], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.buffer_decl([528], "uint8", data=placeholder_d_global_data) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 254abab644a2..fd1e1afa60d9 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -180,8 +180,10 @@ def test_schedule_cache_reads(): @tvm.script.ir_module class DiamondGraphTir: @T.prim_func - def main(placeholder: T.Buffer[(301056,), "int8"], ethosu_write: T.Buffer[(75264,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 56, 56, 96), "int8"], input_ethosu_write: T.Buffer[(1, 56, 56, 24), "int8"]) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder = T.buffer_decl([301056], dtype='int8', data=input_placeholder.data) + ethosu_write = T.buffer_decl([75264], dtype='int8', data=input_ethosu_write.data) buffer1 = T.buffer_decl([2848], "uint8") buffer3 = T.buffer_decl([976], "uint8") p1_data = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True}) diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py index fb41e99a9bcb..4aa12aedf215 100755 --- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py +++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py @@ -304,7 +304,7 @@ def uses_unsupported_physical_dimensions( # pylint: disable=invalid-name def test_param_shapes(self, ir_module, transformed_input_shape, transformed_output_shape): func = ir_module["main"] primfunc_input_shape, primfunc_output_shape = [ - list(func.preflattened_buffer_map[param].shape) for param in func.params + list(func.buffer_map[param].shape) for param in func.params ] assert primfunc_input_shape == transformed_input_shape assert primfunc_output_shape == transformed_output_shape diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index cd0114d46428..106e0f52adac 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -26,15 +26,12 @@ class Module: @T.prim_func def tvm_test_cpacked( - A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + A: T.Buffer[(1,), "float32"], + B: T.Buffer[(1,), "float32"], + C: T.Buffer[(1,), "float32"], + device_context: T.Buffer[(1,), "float32"], ) -> T.handle: - A_0 = T.match_buffer(A, (1,), dtype="float32") - T.preflattened_buffer(A_0, (1,), dtype="float32") - B_0 = T.match_buffer(B, (1,), dtype="float32") - T.preflattened_buffer(B_0, (1,), dtype="float32") - C_0 = T.match_buffer(C, (1,), dtype="float32") - T.preflattened_buffer(C_0, (1,), dtype="float32") - T.evaluate(C) + 
T.evaluate(C.data) @T.prim_func def tir_packed_call() -> None: @@ -59,15 +56,12 @@ def tir_packed_call() -> None: class Expected: @T.prim_func def tvm_test_cpacked( - A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + A: T.Buffer[(1,), "float32"], + B: T.Buffer[(1,), "float32"], + C: T.Buffer[(1,), "float32"], + device_context: T.Buffer[(1,), "float32"], ) -> T.handle: - A_0 = T.match_buffer(A, (1,), dtype="float32") - T.preflattened_buffer(A_0, (1,), dtype="float32") - B_0 = T.match_buffer(B, (1,), dtype="float32") - T.preflattened_buffer(B_0, (1,), dtype="float32") - C_0 = T.match_buffer(C, (1,), dtype="float32") - T.preflattened_buffer(C_0, (1,), dtype="float32") - T.evaluate(C) + T.evaluate(C.data) @T.prim_func def tir_packed_call() -> None: diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 3641f06ab8a2..9f7eee096362 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -30,18 +30,6 @@ def scalar_func(a: T.handle, b: T.handle): A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1] -@T.prim_func -def vector_func(a: T.handle, b: T.handle): - n = T.var("int32") - m = 128 - A = T.match_buffer(a, (n, m)) - B = T.match_buffer(b, (n, m)) - - for i in T.serial(n): - for j in T.vectorized(m): - A[i, j] = A[i, j] + B[i, j] - - def test_domain_touched(): func = scalar_func a, b = [func.buffer_map[var] for var in func.params] @@ -81,7 +69,17 @@ def test_domain_touched(): def test_domain_touched_vector(): - func = tvm.lower(vector_func)["main"] + m = tvm.runtime.convert(128) + + @T.prim_func + def func(a: T.handle, b: T.handle): + n = T.var("int32") + A = T.match_buffer(a, (n * m,)) + B = T.match_buffer(b, (n * m,)) + + for i in T.serial(n): + A[i * m : (i + 1) * m : 1] = A[i * m : (i + 1) * m : 1] + B[i * m : (i + 1) * m : 1] + a, b = [func.buffer_map[var] for var in func.params] assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128 diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py index 4140e7732d7e..3f435366e176 100644 --- a/tests/python/unittest/test_auto_scheduler_feature.py +++ b/tests/python/unittest/test_auto_scheduler_feature.py @@ -203,20 +203,20 @@ def test_gpu_feature(): @T.prim_func def tir_matmul( - A: T.Buffer[(16384,), "float32"], - B: T.Buffer[(16384,), "float32"], - C: T.Buffer[(16384,), "float32"], + A: T.Buffer[(256, 256), "float32"], + B: T.Buffer[(256, 256), "float32"], + C: T.Buffer[(256, 256), "float32"], ) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(A, [128, 128], dtype="float32", data=A.data) - T.preflattened_buffer(B, [128, 128], dtype="float32", data=B.data) - T.preflattened_buffer(C, [128, 128], dtype="float32", data=C.data) + A_flat = T.buffer_decl([16384], dtype="float32", data=A.data) + B_flat = T.buffer_decl([16384], dtype="float32", data=B.data) + C_flat = T.buffer_decl([16384], dtype="float32", data=C.data) # body for x, y in T.grid(128, 128): - C[x * 128 + y] = T.float32(0) + C_flat[x * 128 + y] = T.float32(0) for k in T.serial(128): - C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] + C_flat[x * 128 + y] = C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k] def test_primfunc_without_lowering(): diff --git a/tests/python/unittest/test_lower_build.py 
b/tests/python/unittest/test_lower_build.py index bd820b617c2d..665697b84be9 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -54,40 +54,44 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: class LoweredModule: @T.prim_func def main( - A: T.Buffer[(16384,), "float32"], - B: T.Buffer[(16384,), "float32"], - C: T.Buffer[(16384,), "float32"], + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - T.preflattened_buffer(A, [128, 128], data=A.data) - T.preflattened_buffer(B, [128, 128], data=B.data) - T.preflattened_buffer(C, [128, 128], data=C.data) + A_flat = T.buffer_decl([16384], data=A.data) + B_flat = T.buffer_decl([16384], data=B.data) + C_flat = T.buffer_decl([16384], data=C.data) # body for x, y in T.grid(128, 128): - C[x * 128 + y] = 0.0 + C_flat[x * 128 + y] = 0.0 for k in T.serial(0, 128): - C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] + C_flat[x * 128 + y] = ( + C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k] + ) @tvm.script.ir_module class LoweredTIRModule: @T.prim_func def main( - A: T.Buffer[(16384,), "float32"], - B: T.Buffer[(16384,), "float32"], - C: T.Buffer[(16384,), "float32"], + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(A, [128, 128], data=A.data) - T.preflattened_buffer(B, [128, 128], data=B.data) - T.preflattened_buffer(C, [128, 128], data=C.data) + A_flat = T.buffer_decl([16384], data=A.data) + B_flat = T.buffer_decl([16384], data=B.data) + C_flat = T.buffer_decl([16384], data=C.data) # body for x, y in T.grid(128, 128): - C[x * 128 + y] = 0.0 + C_flat[x * 128 + y] = 0.0 for k in T.serial(0, 128): - C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] + C_flat[x * 128 + y] = ( + C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k] + ) def test_lower_build_te_schedule(): diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 870208499e7a..513e04dc2090 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -40,9 +40,9 @@ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for j in T.serial(0, 16): C[i, j] = B_new[0, j] * 2.0 - def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]): + A = T.buffer_decl(256, dtype="float32", data=input_A.data) + C = T.buffer_decl(256, dtype="float32", data=input_C.data) for i in T.serial(0, 16): B_new_data = T.allocate([16], "float32", scope="global") B_new = T.buffer_decl([16], "float32", scope="global", data=B_new_data) @@ -71,9 +71,9 @@ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for j in T.serial(0, 16): C[i, j] = B_new[0, j] * 2.0 - def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - 
T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]): + A = T.buffer_decl(256, dtype="float32", data=input_A.data) + C = T.buffer_decl(256, dtype="float32", data=input_C.data) for i in T.serial(0, 16): B_new_data = T.allocate([16], "float32", "global") B_new = T.buffer_decl(16, "float32", data=B_new_data) @@ -100,9 +100,9 @@ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for j in range(0, 16): C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 - def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]): + A = T.buffer_decl(256, dtype="float32", data=input_A.data) + C = T.buffer_decl(256, dtype="float32", data=input_C.data) i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ -134,10 +134,10 @@ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: C[i, j] = B[j] * 2.0 def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, n * m, "float32") - C = T.match_buffer(c, n * m, "float32") - T.preflattened_buffer(A, (n, m), "float32", data=A.data) - T.preflattened_buffer(C, (n, m), "float32", data=C.data) + input_A = T.match_buffer(a, (n, m), "float32") + input_C = T.match_buffer(c, (n, m), "float32") + A = T.buffer_decl(n * m, "float32", data=input_A.data) + C = T.buffer_decl(n * m, "float32", data=input_C.data) for i in range(0, n): B_data = T.allocate([m], "float32", scope="global") @@ -159,9 +159,9 @@ def before(A: T.Buffer[(4, 32), "float32"], D: T.Buffer[(4, 32), "float32"]): C[i, j] = A[i, j] + B[i, j] D[i, j] = C[i, j] * 2.0 - def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]): - T.preflattened_buffer(A, (4, 32), "float32", data=A.data) - T.preflattened_buffer(D, (4, 32), "float32", data=D.data) + def expected(input_A: T.Buffer[(4, 32), "float32"], input_D: T.Buffer[(4, 32), "float32"]): + A = T.buffer_decl(128, "float32", data=input_A.data) + D = T.buffer_decl(128, "float32", data=input_D.data) for i, j in T.grid(4, 32): B_data = T.allocate([128], "float32", scope="global") @@ -185,9 +185,9 @@ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): for i1, j in T.grid(4, 16): C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 - def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): - T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) - T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) + def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]): + A = T.buffer_decl(256, dtype="float32", data=input_A.data) + C = T.buffer_decl(256, dtype="float32", data=input_C.data) for i0 in T.serial(0, 4): B_new_data = T.allocate([68], "float32", scope="global") B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data) @@ -206,9 +206,9 @@ def before(A: T.Buffer[10, "bool"], B: T.Buffer[10, "bool"]) -> None: for i0 in T.serial(10): B[i0] = A[i0] - def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, "int8"]) -> None: - T.preflattened_buffer(A, [10], dtype="bool", data=A.data) - T.preflattened_buffer(B, [10], dtype="bool", data=B.data) + def expected(input_A: T.Buffer[10, "bool"], input_B: T.Buffer[10, "bool"]) -> None: + A = T.buffer_decl(10, 
dtype="int8", data=input_A.data) + B = T.buffer_decl(10, dtype="int8", data=input_B.data) # body for i0 in T.serial(10): B[i0] = T.cast(T.cast(A[i0], "bool"), "int8") diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 56128155295e..fe48aa7d8fd4 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -544,9 +544,6 @@ def partitioned_concat( A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"], C: T.Buffer[(32,), "float32"] ) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(A, [16], data=A.data) - T.preflattened_buffer(B, [16], data=B.data) - T.preflattened_buffer(C, [32], data=C.data) for i in T.serial(0, 16): C[i] = A[i] for i in T.serial(0, 16): @@ -581,42 +578,46 @@ def partition_from_scheduled_tir(prim_func, pass_cfg): @T.prim_func def partitioned_concat_3( - placeholder: T.Buffer[(50176,), "int8"], - placeholder_1: T.Buffer[(25088,), "int8"], - placeholder_2: T.Buffer[(25088,), "int8"], - T_concat: T.Buffer[(100352,), "int8"], + placeholder: T.Buffer[(1, 64, 28, 28), "int8"], + placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"], + placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], + T_concat: T.Buffer[(1, 128, 28, 28), "int8"], ) -> None: - T.preflattened_buffer(placeholder, [1, 64, 28, 28], "int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, [1, 32, 28, 28], "int8", data=placeholder_1.data) - T.preflattened_buffer(placeholder_2, [1, 32, 28, 28], "int8", data=placeholder_2.data) - T.preflattened_buffer(T_concat, [1, 128, 28, 28], "int8", data=T_concat.data) + placeholder_flat = T.buffer_decl([50176], "int8", data=placeholder.data) + placeholder_1_flat = T.buffer_decl([25088], "int8", data=placeholder_1.data) + placeholder_2_flat = T.buffer_decl([25088], "int8", data=placeholder_2.data) + T_concat_flat = T.buffer_decl([100352], "int8", data=T_concat.data) for i1, i2, i3 in T.grid(64, 28, 28): - T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] + T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_flat[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): - T_concat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1[i1 * 784 + i2 * 28 + i3] + T_concat_flat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1_flat[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): - T_concat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2[i1 * 784 + i2 * 28 + i3] + T_concat_flat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2_flat[i1 * 784 + i2 * 28 + i3] @T.prim_func def concat_func_3( - placeholder: T.Buffer[(50176,), "int8"], - placeholder_1: T.Buffer[(25088,), "int8"], - placeholder_2: T.Buffer[(25088,), "int8"], - T_concat: T.Buffer[(100352,), "int8"], + placeholder: T.Buffer[(1, 64, 28, 28), "int8"], + placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"], + placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], + T_concat: T.Buffer[(1, 128, 28, 28), "int8"], ) -> None: - T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data) - T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data) - T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data) + placeholder_flat = T.buffer_decl([50176], "int8", data=placeholder.data) + placeholder_1_flat = 
T.buffer_decl([25088], "int8", data=placeholder_1.data) + placeholder_2_flat = T.buffer_decl([25088], "int8", data=placeholder_2.data) + T_concat_flat = T.buffer_decl([100352], "int8", data=T_concat.data) for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): for i2, i3 in T.grid(28, 28): if 96 <= i1: - T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3 - 75264] + T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_2_flat[ + i1 * 784 + i2 * 28 + i3 - 75264 + ] if 64 <= i1 and i1 < 96: - T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[i1 * 784 + i2 * 28 + i3 - 50176] + T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_1_flat[ + i1 * 784 + i2 * 28 + i3 - 50176 + ] if i1 < 64: - T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] + T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_flat[i1 * 784 + i2 * 28 + i3] def test_condition_mutually_exclusive(): @@ -628,9 +629,11 @@ def test_condition_mutually_exclusive(): def test_loop_partition_unroll_hint(): @T.prim_func - def main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None: - T.preflattened_buffer(A, [1, 3, 224, 224], "int8", data=A.data) - T.preflattened_buffer(B, [1, 224, 7, 16], "int8", data=B.data) + def main( + A_arg: T.Buffer[(1, 3, 224, 224), "int8"], B_arg: T.Buffer[(1, 224, 7, 16), "int8"] + ) -> None: + A = T.buffer_decl(150528, "int8", data=A_arg.data) + B = T.buffer_decl(25088, "int8", data=B_arg.data) for ax0 in T.serial( 112, annotations={"pragma_loop_partition_hint": True}, @@ -640,9 +643,11 @@ def main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None: B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax0 * 2 + ax2 - 3] @T.prim_func - def partitioned_main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None: - T.preflattened_buffer(A, [1, 3, 224, 224], dtype="int8", data=A.data) - T.preflattened_buffer(B, [1, 224, 7, 16], dtype="int8", data=B.data) + def partitioned_main( + A_arg: T.Buffer[(1, 3, 224, 224), "int8"], B_arg: T.Buffer[(1, 224, 7, 16), "int8"] + ) -> None: + A = T.buffer_decl(150528, dtype="int8", data=A_arg.data) + B = T.buffer_decl(25088, dtype="int8", data=B_arg.data) # body for ax1, ax2, ax3 in T.grid(224, 7, 16): if 3 <= ax2 and ax3 < 3: @@ -688,8 +693,6 @@ def before(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None: @T.prim_func def after(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None: - T.preflattened_buffer(A, [160], dtype="int32", data=A.data) - T.preflattened_buffer(B, [160], dtype="int32", data=B.data) for i in T.serial(10, annotations={"key": "value"}): B[i] = A[i] + 1 for i in T.serial(140, annotations={"key": "value"}): @@ -737,10 +740,6 @@ def after( placeholder_2: T.Buffer[25088, "int8"], T_concat: T.Buffer[100352, "int8"], ) -> None: - T.preflattened_buffer(placeholder, [50176], dtype="int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, [25088], dtype="int8", data=placeholder_1.data) - T.preflattened_buffer(placeholder_2, [25088], dtype="int8", data=placeholder_2.data) - T.preflattened_buffer(T_concat, [100352], dtype="int8", data=T_concat.data) for _ in T.serial(1, annotations={"preserve_unit_loop": True}): for i1, i2, i3 in T.grid(64, 28, 28): T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py index bfa132d4cecf..635badb847bd 100644 --- 
a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -25,12 +25,12 @@ @tvm.script.ir_module class Before: @T.prim_func - def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) - T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) - T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs_flat = T.buffer_decl([8192], dtype="float32", data=inputs.data) + weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -44,24 +44,24 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * 
weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] + conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] @tvm.script.ir_module class After: @T.prim_func - def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) - T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) - T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs_flat = T.buffer_decl([8192], dtype="float32", data=inputs.data) + weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -75,27 +75,27 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): 
conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] + conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] @tvm.script.ir_module class After_simplified: @T.prim_func - def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") - T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) - T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) - T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs_flat = T.buffer_decl([8192], dtype="float32", data=inputs.data) + weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # body T.launch_thread(blockIdx_x, 64) conv2d_transpose_nhwc_local = T.decl_buffer([8], "float32", scope="local") @@ -106,13 +106,13 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + 
ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] + conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg # fmt: on diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index c80cd55ea27e..0c5d77d02b91 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -98,10 +98,10 @@ def ir(A, B): @tvm.testing.requires_cuda def test_sync_read_thread_id_independent_location(): @T.prim_func - def func(p0: T.Buffer[2, "float32"], p1: T.Buffer[2, "float32"]) -> None: + def func(p0_arg: T.Buffer[(1, 2, 1, 1), "float32"], p1: T.Buffer[2, "float32"]) -> None: threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") - T.preflattened_buffer(p0, [1, 2, 1, 1], dtype="float32", data=p0.data) + p0 = T.buffer_decl([2], dtype="float32", data=p0_arg.data) result_local = T.alloc_buffer([1], dtype="float32", scope="local") temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared") T.launch_thread(blockIdx_x, 8) diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 31cc6e07dec3..d1f86814e7d6 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -75,11 +75,8 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(placeholder_4, [150528], dtype="uint8", elem_offset=0, align=64, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(placeholder_5, [1], dtype="int16", elem_offset=0, align=64, offset_factor=1) T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(T_subtract_1, [452], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -90,13 +87,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde 
# function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(placeholder_65, [150528], dtype="int16", elem_offset=0, align=64, offset_factor=1) placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(placeholder_66, [9408], dtype="int16", elem_offset=0, align=64, offset_factor=1) placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(placeholder_67, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(T_cast_21, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body PaddedInput_7_data = T.allocate([157323], "int16", "global") PaddedInput_7 = T.buffer_decl(shape=[157323], dtype="int16", data=PaddedInput_7_data) @@ -118,9 +111,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(placeholder_29, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1) - T.preflattened_buffer(T_cast_7, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body tensor_2_data = T.allocate([200704], "uint8", "global") tensor_2 = T.buffer_decl(shape=[200704], dtype="uint8", data=tensor_2_data) @@ -168,13 +159,9 @@ def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") - T.preflattened_buffer(placeholder_29, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") - T.preflattened_buffer(T_cast_7, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_6_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(slow_memory_7_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.buffer_decl([200704], dtype="uint8") with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): @@ -189,15 +176,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") - T.preflattened_buffer(placeholder_4, [150528], dtype="uint8") placeholder_5 = 
T.match_buffer(placeholder_3, [1], dtype="int16") - T.preflattened_buffer(placeholder_5, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") - T.preflattened_buffer(T_subtract_1, [452], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_2_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(slow_memory_3_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @@ -205,17 +187,11 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None: placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") - T.preflattened_buffer(placeholder_65, [150528], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") - T.preflattened_buffer(placeholder_66, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") - T.preflattened_buffer(placeholder_67, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") - T.preflattened_buffer(T_cast_21, [289], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_4_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(slow_memory_5_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.buffer_decl([157323], "int16") with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): @@ -280,11 +256,8 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") - T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - T.preflattened_buffer(T_cast_1, [215], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -294,13 +267,9 @@ def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") - T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") - T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") - T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - T.preflattened_buffer(T_cast_5, [215], dtype="int16") # body PaddedInput_1_data = T.allocate([379456], "int16", "global") PaddedInput_1 = T.buffer_decl(shape=[379456], dtype="int16", data=PaddedInput_1_data) @@ -321,13 +290,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") - T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") - T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") - T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - T.preflattened_buffer(T_add_1, [407], dtype="int32") # body PaddedInput_2_data = T.allocate([360000], "int16", "global") PaddedInput_2 = T.buffer_decl(shape=[360000], dtype="int16", data=PaddedInput_2_data) @@ -349,15 +314,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") - T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") - T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") - T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") - T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - T.preflattened_buffer(T_cast_7, [407], dtype="uint8") # body PaddedInput_3_data = T.allocate([360000], "int16", "global") PaddedInput_3 = T.buffer_decl(shape=[360000], dtype="int16", data=PaddedInput_3_data) @@ -396,13 +356,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") - T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") - T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") 
- T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - T.preflattened_buffer(T_cast_3, [215], dtype="int16") # body PaddedInput_data = T.allocate([360000], "int16", "global") PaddedInput = T.buffer_decl([360000], "int16", data=PaddedInput_data) @@ -426,13 +382,9 @@ class ResnetStructurePlanned: @T.prim_func def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") - T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - T.preflattened_buffer(T_cast_1, [215], dtype="int16") global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_1_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -440,17 +392,11 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") - T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") - T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") - T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") - T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - T.preflattened_buffer(T_cast_7, [407], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_5_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.buffer_decl([360000], 'int16') with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): @@ -470,15 +416,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None: placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") - 
T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") - T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") - T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - T.preflattened_buffer(T_add_1, [407], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_4_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): @@ -498,15 +439,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None: placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") - T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") - T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") - T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - T.preflattened_buffer(T_cast_3, [215], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_2_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): @@ -525,15 +461,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None: placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") - T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") - T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") - T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - T.preflattened_buffer(T_cast_5, [215], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_3_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_1_let = T.buffer_decl([379456], "int16") with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): @@ -630,9 +561,6 @@ def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr[T.uint8]) -> None: global_workspace_1_buffer_var = T.match_buffer( 
global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 ) - T.preflattened_buffer( - global_workspace_1_buffer_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 - ) dense_let = T.buffer_decl([10], "int32") with T.let(dense_let.data, T.address_of(global_workspace_1_buffer_var[0], dtype="handle")): T.evaluate( diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 32293cccdcf1..f542080f89f9 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -565,26 +565,6 @@ def non_integer_typed_block_iter(): check_error(non_integer_typed_block_iter, 3) -def test_preflattened_buffer_map_align(): - def preflattened_buffer_map_align_nonint(foo: T.handle): - foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer( - foo_1, [1], align="bar" - ) # check_error: align: want int or IntImm, got 'bar' - - check_error(preflattened_buffer_map_align_nonint, 3) - - -def test_preflattened_buffer_map_offset_factor(): - def preflattened_buffer_map_offset_factor_nonint(foo: T.handle): - foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer( - foo_1, [1], offset_factor="bar" - ) # check_error: offset_factor: want int or IntImm, got 'bar' - - check_error(preflattened_buffer_map_offset_factor_nonint, 3) - - def test_illegal_buffer_slice(): def strided_buffer_region(A: T.handle): # do not allow stride in buffer region diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 29e03f8bb63f..7d542c7bc7bd 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -41,7 +41,6 @@ def test_ir_builder_tir_primfunc_base(): body=tir.Evaluate(0), ret_type=None, buffer_map=None, - preflattened_buffer_map=None, attrs=None, ) @@ -60,7 +59,6 @@ def test_ir_builder_tir_primfunc_complete(): T.func_attr({"key": "value"}) T.func_ret(tvm.ir.PrimType("int64")) buffer_d = T.match_buffer(d, (64, 64), "int64") - T.preflattened_buffer(e, (32, 32), "int8", data=e.data) T.evaluate(0) # the prim_func generated by IRBuilder @@ -83,9 +81,6 @@ def test_ir_builder_tir_primfunc_complete(): body=tir.Evaluate(0), ret_type=tvm.ir.PrimType("int64"), buffer_map={c_handle: c_buffer, d_handle: d_buffer, e_handle: e_buffer}, - preflattened_buffer_map={ - e_handle: tir.decl_buffer((32, 32), "int8", name="e_preflatten", data=e_buffer.data) - }, attrs=tvm.ir.make_node("DictAttrs", key="value"), ) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 16f1cb04945a..e2960686ac7c 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -186,23 +186,6 @@ def test_dynamic_shape_gemm(): assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) -@T.prim_func -def preflattened_buffer_map(A: T.handle, B: T.handle): - A_1 = T.match_buffer(A, [1]) - T.preflattened_buffer(A_1, [1], align=1, offset_factor=2) - B_1 = T.match_buffer(B, [1]) - T.preflattened_buffer(B_1, [1]) - B_1[0] = A_1[0] - - -def test_preflattened_buffer_map(): - A_var = [ - k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A" - ][0] - assert preflattened_buffer_map.preflattened_buffer_map[A_var].data_alignment == 1 - assert preflattened_buffer_map.preflattened_buffer_map[A_var].offset_factor == 2 - - 
@T.prim_func def match_buffer_int64(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32")