From e97d9f1c6b0013f7147dd05eb3f2ededb377b257 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 7 Apr 2022 13:26:34 -0500 Subject: [PATCH 01/25] [Draft][TIR] Remove PrimFuncNode::preflattened_buffer_map `PrimFuncNode::preflattened_buffer_map` was introduced in https://github.com/apache/tvm/pull/9727, in order to maintain a record of the pre-flattened buffer shape until it can be used in `MakePackedAPI`. This commit instead maintains the pre-flattened shapes in `PrimFuncNode::buffer_map`, while the body of the function uses a flattened buffer alias. Passes LLVM tests in test_target_codegen_llvm.py as initial proof of concept. --- include/tvm/tir/function.h | 43 +--- python/tvm/script/context_maintainer.py | 3 - python/tvm/script/parser.py | 1 - python/tvm/script/tir/special_stmt.py | 73 ------- python/tvm/tir/function.py | 7 - src/printer/tvmscript_printer.cc | 20 -- src/relay/backend/aot_executor_codegen.cc | 2 +- .../example_target_hooks/relay_to_tir.cc | 2 +- src/relay/transforms/fold_constant.cc | 4 +- src/tir/analysis/device_constraint_utils.cc | 22 +- src/tir/contrib/ethosu/passes.cc | 5 +- src/tir/ir/function.cc | 10 +- src/tir/transforms/bf16_legalize.cc | 29 --- src/tir/transforms/flatten_buffer.cc | 17 +- src/tir/transforms/make_packed_api.cc | 4 +- src/tir/transforms/storage_flatten.cc | 201 ++++++++++-------- src/tir/usmp/transform/assign_pool_info.cc | 4 +- .../convert_pool_allocations_to_offsets.cc | 10 +- 18 files changed, 154 insertions(+), 303 deletions(-) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 79fbd0932a6d..81820260a318 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -88,33 +88,22 @@ class PrimFuncNode : public BaseFuncNode { * While we could have express parameter unpacking and constraint using * normal statements, making buffer_map as first class citizen of PrimFunc * will make program analysis much easier. - */ - Map buffer_map; - - /*! \brief The buffer map prior to flattening. - * - * This contains the buffers as they exists prior to flattening, and - * is used for validating an input tensor passed into the packed - * API. Any buffer that is present in `buffer_map` but not present - * in `preflattened_buffer_map` is assumed to be the same before - * and after flattening (e.g. a 1-d tensor that is backed by 1-d - * flat memory). * - * TODO(Lunderberg): Remove preflattened_buffer_map, and instead - * declare each flattened buffer as aliasing the original tensor - * shape. This should include improving the StmtExprMutator to - * provide easier interactions with Buffer objects, so that the - * bookkeeping of relationships between buffers doesn't need to be - * repeated across several transforms. + * Prior to buffer flattening, which is performed either in + * StorageFlatten for TE-based schedules or in FlattenBuffer for + * TIR-based schedules, these buffer objects are used directly in + * the body of the function. After buffer flattening, these buffer + * objects remain unflattened for use in argument validation, but + * all usage in the body of the function is done through a + * flattened alias of the buffer. 
*/ - Map preflattened_buffer_map; + Map buffer_map; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("params", ¶ms); v->Visit("body", &body); v->Visit("ret_type", &ret_type); v->Visit("buffer_map", &buffer_map); - v->Visit("preflattened_buffer_map", &preflattened_buffer_map); v->Visit("attrs", &attrs); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -123,7 +112,6 @@ class PrimFuncNode : public BaseFuncNode { bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { // visit params and buffer_map first as they contains defs. return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && - equal(preflattened_buffer_map, other->preflattened_buffer_map) && equal(ret_type, other->ret_type) && equal(body, other->body) && equal(attrs, other->attrs); } @@ -131,7 +119,6 @@ class PrimFuncNode : public BaseFuncNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(params); hash_reduce(buffer_map); - hash_reduce(preflattened_buffer_map); hash_reduce(ret_type); hash_reduce(body); hash_reduce(attrs); @@ -169,21 +156,13 @@ class PrimFunc : public BaseFunc { * PrimFunc. (e.g. a buffer of shape ``[1024]`` originally * generated as a tensor of shape ``[32, 32]``) * - * \param preflattened_buffer_map The buffer map for - * parameter buffer unpacking. This contains buffer - * objects as they are expected to be passed in by the - * callee. (e.g. a buffer of shape ``[32, 32]`` originally - * generated as a tensor of shape ``[32, 32]``) - * * \param attrs Additional function attributes. * * \param span The location of this object in the source code. */ - TVM_DLL PrimFunc( - Array params, Stmt body, Type ret_type = VoidType(), - Map buffer_map = Map(), - Optional> preflattened_buffer_map = Optional>(), - DictAttrs attrs = NullValue(), Span span = Span()); + TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), + Map buffer_map = Map(), + DictAttrs attrs = NullValue(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index f7f16855c752..b84b7d398084 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -129,8 +129,6 @@ class ContextMaintainer: """List[Var]: The function parameters""" func_buffer_map: Mapping[Var, Buffer] = {} """Mapping[Var, Buffer]: The function buffer map""" - func_preflattened_buffer_map: Mapping[Var, Buffer] = {} - """Mapping[Var, Buffer]: The function buffer map, prior to any flattening.""" func_dict_attr: Mapping[str, Object] = {} """Mapping[str, Object]: The function attrs""" func_var_env_dict: Mapping[Var, str] = {} @@ -160,7 +158,6 @@ def __init__( # function context self.func_params = [] self.func_buffer_map = {} - self.func_preflattened_buffer_map = {} self.func_dict_attr = {} self.func_var_env_dict = {} # parser and analyzer diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index fe71b064320f..8c91171435ac 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -487,7 +487,6 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: body, ret_type, buffer_map=self.context.func_buffer_map, - preflattened_buffer_map=self.context.func_preflattened_buffer_map, attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None, span=tvm_span_from_synr(node.span), ) diff --git a/python/tvm/script/tir/special_stmt.py 
b/python/tvm/script/tir/special_stmt.py index 15502055b7fc..5e6c7c8fe739 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -868,79 +868,6 @@ def func_attr(dict_attr, span): super().__init__(func_attr, def_symbol=False) -@register -class PreflattenedBufferMap(SpecialStmt): - """Special Stmt for declaring the PrimFunc::preflattened_buffer_map - - Example - ------- - .. code-block:: python - A0 = T.match_buffer(A, (48,), dtype="float32") - T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32") - """ - - def __init__(self): - def preflattened_buffer( - postflattened, - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - span=None, - ): - - param = None - for key, value in self.context.func_buffer_map.items(): - if value.same_as(postflattened): - param = key - break - - assert ( - param is not None - ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map." - - if data is None: - data = self.context.func_buffer_map[param].data - - buffer_name: str = f"{postflattened.name}_preflatten" - if align != -1: - if isinstance(align, IntImm): - align = align.value - else: - assert isinstance(align, int), f"align: want int or IntImm, got {align!r}" - - if offset_factor != 0: - if isinstance(offset_factor, IntImm): - offset_factor = offset_factor.value - else: - assert isinstance( - offset_factor, int - ), f"offset_factor: want int or IntImm, got {offset_factor!r}" - - preflattened = tvm.tir.decl_buffer( - shape, - dtype, - buffer_name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - span=span, - ) - - self.context.func_preflattened_buffer_map[param] = preflattened - - super().__init__(preflattened_buffer, def_symbol=False) - - @register class TargetAttrValue(SpecialStmt): """Special Stmt for target attr value. diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index d84513e072d3..961b53fcf8da 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -47,9 +47,6 @@ class PrimFunc(BaseFunc): buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer] The buffer binding map. - preflattened_buffer_map : Optional[Map[tvm.tir.Var, tvm.tir.Buffer]] - The buffer binding map, prior to any flattening. 
- attrs: Optional[tvm.Attrs] Attributes of the function, can be None @@ -63,14 +60,12 @@ def __init__( body, ret_type=None, buffer_map=None, - preflattened_buffer_map=None, attrs=None, span=None, ): param_list = [] buffer_map = {} if buffer_map is None else buffer_map - preflattened_buffer_map = {} if preflattened_buffer_map is None else preflattened_buffer_map for x in params: x = tvm.runtime.convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): @@ -88,7 +83,6 @@ def __init__( body, ret_type, buffer_map, - preflattened_buffer_map, attrs, span, ) # type: ignore @@ -114,7 +108,6 @@ def with_body(self, new_body, span=None): new_body, self.ret_type, self.buffer_map, - self.preflattened_buffer_map, self.attrs, span, ) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 99d1a7845d3f..a138236de381 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1541,26 +1541,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { body << Print((*it).first) << ", " << memo_buf_decl_[buf]; body << ")" << Doc::NewLine(); } - // print preflattened buffer map - for (const auto& param : op->params) { - auto pf_buf_it = op->preflattened_buffer_map.find(param); - if (pf_buf_it != op->preflattened_buffer_map.end()) { - const Buffer& preflattened = (*pf_buf_it).second; - - auto buf_it = op->buffer_map.find(param); - ICHECK(buf_it != op->buffer_map.end()) << "Found pre-flattened buffer " << preflattened->name - << " with no corresponding post-flatten buffer."; - const Buffer& postflattened = (*buf_it).second; - - // Call Print() without assigning in order to fill memo_buf_decl_. - Print(preflattened); - buf_not_in_headers_.insert(preflattened.get()); - ICHECK(memo_buf_decl_.count(preflattened)); - - body << tir_prefix_ << ".preflattened_buffer(" << Print(postflattened) << ", " - << memo_buf_decl_.at(preflattened) << ")" << Doc::NewLine(); - } - } // print body body << "# body" << Doc::NewLine(); if (op->body->IsInstance() && diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 60f108aacf66..71c7ce429584 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -803,7 +803,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, + return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, DictAttrs(dict_attrs)); } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index c498baa6d11d..7c84494c56b0 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -81,7 +81,7 @@ class ConvertAddToSubtract : public MixedModeMutator { }; tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), - buffer_map, {}, DictAttrs(dict_attrs)); + buffer_map, DictAttrs(dict_attrs)); // Switch to TIRToRuntime hook for testing Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 9dec840be0a7..1ddb0e44eac1 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ 
-259,7 +259,9 @@ class ConstantFolder : public MixedModeMutator { // Use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) - With fresh_build_ctx(transform::PassContext::Create()); + auto context = transform::PassContext::Create(); + context->instruments = transform::PassContext::Current()->instruments; + With fresh_build_ctx(context); Map dict = (module_->attrs.defined()) ? Map(module_->attrs.CopyOnWrite()->dict) diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 1309681513a9..9a1e5ba38cad 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -210,8 +210,6 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { // Start with a copy of the current prim_func buffer map. Map new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end()); - Map new_preflattened_buffer_map(prim_func->preflattened_buffer_map.begin(), - prim_func->preflattened_buffer_map.end()); bool any_change = false; // For each constrained parameter... @@ -225,23 +223,6 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { any_change = true; } new_buffer_map.Set(param, new_buffer); - - // Rewrite the pre-flattened buffers to account for constraint. - // This only has an impact if the IRModule being analyzed has - // already been run through the StorageFlatten or FlattenBuffer - // passes. - if (auto opt = prim_func->preflattened_buffer_map.Get(param)) { - Buffer pf_buffer = opt.value(); - if (pf_buffer.same_as(buffer)) { - new_preflattened_buffer_map.Set(param, new_buffer); - } else { - const Buffer new_buffer = RewriteBuffer(pf_buffer, virtual_device); - if (!new_buffer.same_as(pf_buffer)) { - any_change = true; - } - new_preflattened_buffer_map.Set(param, new_buffer); - } - } } // Make sure we have accounted for all prim_func parameters. 
CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); @@ -259,8 +240,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { if (any_change) { return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type, - std::move(new_buffer_map), std::move(new_preflattened_buffer_map), - prim_func->attrs, prim_func->span); + std::move(new_buffer_map), prim_func->attrs, prim_func->span); } else { return prim_func; } diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 45161499f5be..fc48df52682c 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -76,9 +76,8 @@ class HoistAllocatesMutator : public StmtExprMutator { current_alloc->span); } - PrimFunc new_main_func = - PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, main_func->buffer_map, - main_func->preflattened_buffer_map, main_func->attrs); + PrimFunc new_main_func = PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, + main_func->buffer_map, main_func->attrs); return new_main_func; } diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index b9c3029d3c25..f58dd8aa820c 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -29,9 +29,7 @@ namespace tvm { namespace tir { // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, - Map buffer_map, - Optional> preflattened_buffer_map, DictAttrs attrs, - Span span) { + Map buffer_map, DictAttrs attrs, Span span) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { @@ -42,7 +40,6 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, n->body = std::move(body); n->ret_type = std::move(ret_type); n->buffer_map = std::move(buffer_map); - n->preflattened_buffer_map = preflattened_buffer_map.value_or(Map()); n->attrs = std::move(attrs); n->checked_type_ = n->func_type_annotation(); n->span = std::move(span); @@ -121,9 +118,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("tir.PrimFunc") .set_body_typed([](Array params, Stmt body, Type ret_type, - Map buffer_map, - Map preflattened_buffer_map, DictAttrs attrs, Span span) { - return PrimFunc(params, body, ret_type, buffer_map, preflattened_buffer_map, attrs, span); + Map buffer_map, DictAttrs attrs, Span span) { + return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); TVM_REGISTER_GLOBAL("tir.TensorIntrin") diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 193584f84b47..e6a5964a7a0f 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -287,37 +287,8 @@ class BF16LowerRewriter : public StmtExprMutator { } } - // Most passes do not change the preflattened buffer map, nor - // should they change it. This is an exception, because the Var - // associated with the `BufferNode::data` in - // `PrimFunc::buffer_map` may be replaced, and the corresponding - // Var in the `PrimFunc::preflattened_buffer_map` must also be - // replaced. 
- Map new_preflattened_buffer_map; - for (auto& itr : op->preflattened_buffer_map) { - auto param_var = itr.first; - auto oldbuf = itr.second; - if (oldbuf->dtype.is_bfloat16()) { - auto it = new_buffer_map.find(param_var); - ICHECK(it != new_buffer_map.end()) - << "PrimFunc parameter " << param_var->name_hint - << " is associated with the pre-flattened buffer " << oldbuf->name - << ", but isn't associated with any post-flatten buffer."; - const Buffer& flatbuf = (*it).second; - DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); - auto newbuf = Buffer(flatbuf->data, dtype, oldbuf->shape, oldbuf->strides, - oldbuf->elem_offset, oldbuf->name, oldbuf->data_alignment, - oldbuf->offset_factor, oldbuf->buffer_type); - buffer_remap_[oldbuf] = newbuf; - new_preflattened_buffer_map.Set(param_var, newbuf); - } else { - new_preflattened_buffer_map.Set(param_var, oldbuf); - } - } - if (buffer_remap_.size() != 0) { op->buffer_map = new_buffer_map; - op->preflattened_buffer_map = new_preflattened_buffer_map; } } diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index c7cc51d27113..378eeb2023c3 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -51,24 +51,19 @@ PrimExpr BufferArea(const Buffer& buffer) { class BufferFlattener : public StmtExprMutator { public: static PrimFunc Flatten(PrimFunc func) { - Map preflattened_buffer_map = - Merge(func->buffer_map, func->preflattened_buffer_map); - - auto pass = BufferFlattener(func->buffer_map); + auto pass = BufferFlattener(); auto writer = func.CopyOnWrite(); writer->body = pass.VisitStmt(func->body); - writer->preflattened_buffer_map = preflattened_buffer_map; - writer->buffer_map = pass.updated_extern_buffer_map_; + // The buffers in func->buffer_map are deliberately left + // unflattened, as they are used for validation of user-provided + // arguments. The flattened buffers used in the updated + // function body alias the argument buffers. return func; } private: - explicit BufferFlattener(const Map& extern_buffer_map) { - for (const auto& kv : extern_buffer_map) { - updated_extern_buffer_map_.Set(kv.first, GetFlattenedBuffer(kv.second)); - } - } + BufferFlattener() {} Stmt VisitStmt_(const BlockRealizeNode* op) final { // We have convert blocks into opaque blocks in previous passes. 
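For reference, a minimal TVMScript sketch of the convention that FlattenBuffer (and StorageFlatten below) now produce; the buffer name, shape, and dtype here are hypothetical and chosen only for illustration. The parameter keeps its pre-flattened shape in buffer_map for argument validation, while the body goes through a flat T.buffer_decl alias sharing the parameter's data var, replacing the old T.preflattened_buffer bookkeeping.

# Hypothetical post-flattening form; names and shapes are illustrative only.
from tvm.script import tir as T

@T.prim_func
def after_flatten(A: T.Buffer[(1, 4, 4, 3), "float32"]) -> None:
    # The parameter buffer stays unflattened in the buffer_map.
    # The body works through a flat alias over the same allocation.
    A_flat = T.buffer_decl([48], "float32", data=A.data)
    for i in T.serial(48):
        A_flat[i] = T.float32(0)
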
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 35c96e4fe4e1..7af7f97991fc 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -224,9 +224,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { continue; } - if (func_ptr->preflattened_buffer_map.count(param)) { - buffer_def.emplace_back(v_arg, func_ptr->preflattened_buffer_map[param]); - } else if (func_ptr->buffer_map.count(param)) { + if (func_ptr->buffer_map.count(param)) { buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]); } else { var_def.emplace_back(v_arg, param); diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index f97f91a1e501..60009cf9dfc8 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -401,6 +401,7 @@ class BufferStrideLegalize : public StmtExprMutator { auto fptr = func.CopyOnWrite(); fptr->body = pass(std::move(fptr->body)); + fptr->buffer_map = pass.UpdatedExternBufferMap(); if (auto map = func->attrs.GetAttr>>("layout_transform_map")) { func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value())); } @@ -419,7 +420,6 @@ class BufferStrideLegalize : public StmtExprMutator { BufferEntry entry; entry.remap_to = with_strides; entry.in_scope = true; - entry.is_external = true; buf_map_[buf] = entry; } updated_extern_buffer_map_.Set(kv.first, with_strides); @@ -442,51 +442,54 @@ class BufferStrideLegalize : public StmtExprMutator { Map UpdatedExternBufferMap() const { return updated_extern_buffer_map_; } Buffer WithStrides(Buffer buf) { - auto it = buf_map_.find(buf); + auto cache_key = buf; + + auto it = buf_map_.find(cache_key); if (it != buf_map_.end()) { const BufferEntry& entry = it->second; ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer"; return entry.remap_to; } + Array shape = buf->shape; + if (buf->strides.size()) { ICHECK_EQ(buf->strides.size(), buf->shape.size()) << "Buffer " << buf << " has inconsistent strides/shape."; - return buf; - } - - // Keeping this to have matched behavior to previous version. - // There are many parts of the codebase that assume that a strided - // array cannot be compact. For example, ArgBinder::BindBuffer - // and tir.Specialize. - if (dim_align_.count(buf) == 0) { - return buf; - } - - // Can't define the strides for a buffer without a known shape. - Array shape = buf->shape; - if (shape.size() == 0) { - return buf; - } - - std::vector rstrides; - const std::vector& avec = dim_align_[buf]; - int first_dim = 0; - PrimExpr stride = make_const(shape[first_dim].dtype(), 1); - for (size_t i = shape.size(); i != 0; --i) { - size_t dim = i - 1; - if (dim < avec.size() && avec[dim].align_factor != 0) { - PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); - PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = bound_analyzer_->Simplify(stride); + } else if (dim_align_.count(buf) == 0) { + // Keeping this to have matched behavior to previous version. + // There are many parts of the codebase that assume that a + // strided array cannot be compact. For example, + // ArgBinder::BindBuffer and tir.Specialize. To avoid breaking + // these, do not define the strides unless required for a + // non-compact array. + } else if (shape.size() == 0) { + // Can't define the strides for a buffer without a known shape. 
+ } else { + // With everything checked, can now define the updated strides + std::vector rstrides; + const std::vector& avec = dim_align_[buf]; + int first_dim = 0; + PrimExpr stride = make_const(shape[first_dim].dtype(), 1); + for (size_t i = shape.size(); i != 0; --i) { + size_t dim = i - 1; + if (dim < avec.size() && avec[dim].align_factor != 0) { + PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); + PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); + stride = bound_analyzer_->Simplify(stride); + } + rstrides.push_back(stride); + stride = stride * shape[dim]; } - rstrides.push_back(stride); - stride = stride * shape[dim]; + + buf.CopyOnWrite()->strides = Array(rstrides.rbegin(), rstrides.rend()); } - auto ptr = buf.CopyOnWrite(); - ptr->strides = Array(rstrides.rbegin(), rstrides.rend()); + BufferEntry entry; + entry.remap_to = buf; + entry.in_scope = true; + buf_map_[cache_key] = entry; return buf; } @@ -512,16 +515,10 @@ class BufferStrideLegalize : public StmtExprMutator { Buffer target_with_strides = WithStrides(Downcast(arr[1])); Buffer source_with_strides = WithStrides(source); - { - BufferEntry entry; - entry.remap_to = source_with_strides; - entry.in_scope = true; - entry.is_external = false; - buf_map_[source] = entry; - } - Stmt body = this->VisitStmt(op->body); + buf_map_[source].in_scope = false; + return AttrStmt(Array{source_with_strides, target_with_strides}, op->attr_key, op->value, body, op->span); } else { @@ -559,13 +556,6 @@ class BufferStrideLegalize : public StmtExprMutator { Stmt VisitStmt_(const BufferRealizeNode* op) final { Buffer key = op->buffer; Buffer with_strides = WithStrides(op->buffer); - { - BufferEntry entry; - entry.remap_to = with_strides; - entry.in_scope = true; - entry.is_external = false; - buf_map_[key] = entry; - } Stmt stmt = StmtExprMutator::VisitStmt_(op); @@ -588,22 +578,14 @@ class BufferStrideLegalize : public StmtExprMutator { template Node VisitBufferAccess(Node node) { - auto alloc_key = node->buffer->data.get(); - if (!buf_map_.count(node->buffer) && buffer_var_defines_.count(alloc_key)) { - BufferEntry entry; - entry.remap_to = WithStrides(node->buffer); - entry.in_scope = true; - entry.is_external = false; - buf_map_[node->buffer] = entry; - } - auto it = buf_map_.find(node->buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << node->buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot access a buffer " << node->buffer->name << ", out of scope"; + ICHECK(it == buf_map_.end() || it->second.in_scope) + << "Cannot access a buffer " << node->buffer->name << ", out of scope"; - auto writer = node.CopyOnWrite(); - writer->buffer = e.remap_to; + auto with_strides = WithStrides(node->buffer); + if (!with_strides.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = with_strides; + } return node; } @@ -622,7 +604,6 @@ class BufferStrideLegalize : public StmtExprMutator { struct BufferEntry { Buffer remap_to; bool in_scope; - bool is_external; }; std::unordered_map buf_map_; @@ -845,6 +826,7 @@ class BufferBindUnwrapper : public StmtExprMutator { BufferEntry e; e.buffer = kv.second; e.external = true; + var_to_buffer_[kv.second->data.get()] = kv.second; buf_map_[kv.second.get()] = std::move(e); } } @@ -1000,6 +982,7 @@ class BufferBindUnwrapper : public StmtExprMutator { BufferEntry e; e.bounds = op->bounds; e.buffer = op->buffer; + 
var_to_buffer_[op->buffer->data.get()] = op->buffer; buf_map_[key] = std::move(e); } @@ -1088,6 +1071,7 @@ class BufferBindUnwrapper : public StmtExprMutator { source_info.buffer = source; source_info.remap = std::make_unique(remap); + var_to_buffer_[source->data.get()] = source; buf_map_[source.get()] = std::move(source_info); } @@ -1153,18 +1137,70 @@ class BufferBindUnwrapper : public StmtExprMutator { }; const BufferEntry& GetBufferEntry(Buffer buffer) { - auto alloc_key = buffer->data.get(); - if (!buf_map_.count(buffer.get()) && buffer_var_defines_.count(alloc_key)) { + if (buf_map_.count(buffer.get())) { + const BufferEntry& e = buf_map_[buffer.get()]; + ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; + return e; + } else if (buffer_var_defines_.count(buffer->data.get())) { + // The buffer var was defined, but the buffer hasn't been seen + // before. BufferEntry entry; entry.buffer = buffer; + var_to_buffer_[buffer->data.get()] = buffer; buf_map_[buffer.get()] = std::move(entry); - } + return buf_map_[buffer.get()]; + } else if (var_remap_.count(buffer->data.get())) { + // The buffer var is an alias of a bound buffer. Only + // supported if the bound buffer has no offsets. In this + // case, we just need to make a new aliasing buffer that + // shares the remapped data variable. + Var old_var = buffer->data; + Var new_var = Downcast(var_remap_[old_var.get()]); - auto it = buf_map_.find(buffer.get()); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; - return it->second; + { + ICHECK(var_to_buffer_.count(old_var.get())) + << "Cannot find remap information for aliased buffer var " << old_var->name_hint + << ", required to verify this alias is legal."; + const Buffer& aliased_buffer = var_to_buffer_[old_var.get()]; + const BufferEntry& entry = buf_map_[aliased_buffer.get()]; + if (entry.remap) { + for (const auto& begin : entry.remap->begins) { + ICHECK(is_zero(begin)) << "Aliasing of buffer with offset is not supported"; + } + } + } + + { + Buffer new_buf = buffer; + new_buf.CopyOnWrite()->data = new_var; + + RemapInfo remap_info; + remap_info.target = new_buf; + remap_info.begins = Array(buffer->shape.size(), 0); + remap_info.extents = buffer->shape; + + BufferEntry entry; + entry.buffer = buffer; + entry.remap = std::make_unique(remap_info); + entry.in_scope = true; + var_to_buffer_[buffer->data.get()] = buffer; + buf_map_[buffer.get()] = std::move(entry); + } + return buf_map_[buffer.get()]; + } else if (var_to_buffer_.count(buffer->data.get())) { + // This buffer is an alias of a known buffer, with no remaps. A + // buffer entry should be generated and returned. + BufferEntry entry; + entry.buffer = buffer; + entry.in_scope = true; + var_to_buffer_[buffer->data.get()] = buffer; + buf_map_[buffer.get()] = std::move(entry); + + return buf_map_[buffer.get()]; + } else { + LOG(FATAL) << "Can't work around the undefined buffer"; + return *static_cast(nullptr); + } } // The buffer assignment map @@ -1174,6 +1210,9 @@ class BufferBindUnwrapper : public StmtExprMutator { std::unordered_set illegal_vars_; // Buffer map std::unordered_map buf_map_; + // Map from Var to the Buffer they occurred in. In case of aliased + // buffers, contains the first buffer. + std::unordered_map var_to_buffer_; // Set of vars that have occurred in an AllocateNode, but haven't // yet occurred in a BufferLoad/BufferStore. 
std::unordered_set buffer_var_defines_; @@ -1304,13 +1343,12 @@ class StorageFlattener : public StmtExprMutator { auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer); - Map preflattened_buffer_map = - Merge(func->buffer_map, func->preflattened_buffer_map); - auto fptr = func.CopyOnWrite(); fptr->body = pass(std::move(fptr->body)); - fptr->preflattened_buffer_map = preflattened_buffer_map; - fptr->buffer_map = pass.UpdatedBufferMap(); + // The buffers in func->buffer_map are deliberately left + // unflattened, as they are used for validation of user-provided + // arguments. The flattened buffers used in the updated + // function body alias the argument buffers. return func; }; return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); @@ -1338,15 +1376,12 @@ class StorageFlattener : public StmtExprMutator { } } e.external = true; + buffer_var_defines_.insert(kv.second->data.get()); buf_map_[kv.second] = e; - - updated_extern_buffer_map_.Set(kv.first, e.flattened_buffer); } cache_line_size_ = cache_line_size; } - Map UpdatedBufferMap() { return updated_extern_buffer_map_; } - Stmt VisitStmt_(const StoreNode* op) final { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; return Stmt(); @@ -1505,8 +1540,10 @@ class StorageFlattener : public StmtExprMutator { writer->dtype = DataType::Int(8); } + buffer_var_defines_.insert(op->buffer->data.get()); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); + buffer_var_defines_.erase(op->buffer->data.get()); buf_map_[key].in_scope = false; Stmt ret = @@ -1770,8 +1807,6 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map> buffer_var_map_; // Buffer map std::unordered_map buf_map_; - // The extern buffer map, updated to include flattened buffers. - Map updated_extern_buffer_map_; // Collects shapes. std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. 
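The USMP and backend call sites below simply drop the removed constructor argument. As a quick illustration of the narrowed signature (all values here are placeholders standing in for whatever a pass has already computed), a PrimFunc is now assembled from params, body, ret_type, buffer_map, attrs, and span only:

# Hypothetical construction with the updated signature; variables are placeholders.
import tvm
from tvm import tir

x = tir.Var("x", "int32")
func = tir.PrimFunc(
    params=[x],
    body=tir.Evaluate(x),
    ret_type=None,   # defaults to void
    buffer_map={},   # pre-flattened buffers remain here after flattening
    attrs=None,
)
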
diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index e291eaa0519e..d0a18b35727e 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -111,8 +111,8 @@ IRModule PoolInfoAssigner::operator()() { if (kv.second->IsInstance()) { func_ = Downcast(kv.second); Stmt body = this->VisitStmt(func_->body); - PrimFunc new_prim_func = PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, - func_->preflattened_buffer_map, func_->attrs); + PrimFunc new_prim_func = + PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, func_->attrs); mod_->Update(gv, new_prim_func); } } diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index dc71e3d60891..3f5d13768875 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -227,8 +227,8 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( if (emit_tvmscript_printable_) { original_attrs = DictAttrs(); } - PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, - si.buffer_map, original_attrs); + PrimFunc ret = + PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs); if (!emit_tvmscript_printable_) { ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); } @@ -379,12 +379,12 @@ IRModule PoolAllocationToOffsetConverter::operator()() { // We dont need attrs of PrimFunc that might include non printable attrs such as target // for unit tests where emit_tvmscript_printable_ is to be used. if (!emit_tvmscript_printable_) { - main_func = PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {}, - main_func->attrs); + main_func = + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs); main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params); } else { main_func = - PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {}, DictAttrs()); + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs()); } module_->Update(gv, main_func); if (!emit_tvmscript_printable_) { From 33b5b63aea829c4df9ea8679919f1862899be535 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 11 Apr 2022 08:48:00 -0500 Subject: [PATCH 02/25] Fix lint errors --- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index baadede08d66..fbfdbe9b3b85 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -298,7 +298,6 @@ def _ftransform(f, mod, ctx): new_body, f.ret_type, new_buffer_map, - f.preflattened_buffer_map, f.attrs, f.span, ) @@ -637,7 +636,6 @@ def _ftransform(f, mod, ctx): new_body, f.ret_type, new_buffer_map, - f.preflattened_buffer_map, f.attrs, f.span, ) @@ -872,7 +870,6 @@ def CreatePrimFuncWithoutConstants(const_dict): def _ftransform(f, mod, ctx): new_params = list() new_buffer_map = dict() - new_preflattened_buffer_map = dict() for param_idx in const_dict.keys(): # We are using buffer_var to key the constants as # PrimFunc params of constants will be removed. 
@@ -881,14 +878,11 @@ def _ftransform(f, mod, ctx): if i not in const_dict.keys(): new_params.append(param) new_buffer_map[param] = f.buffer_map[param] - if param in f.preflattened_buffer_map: - new_preflattened_buffer_map[param] = f.preflattened_buffer_map[param] return tvm.tir.PrimFunc( new_params, f.body, f.ret_type, new_buffer_map, - new_preflattened_buffer_map, f.attrs, f.span, ) From d855300c6030b154b2b9f8684bf82d27c54833ce Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 11 Apr 2022 08:48:29 -0500 Subject: [PATCH 03/25] Replacing T.preflattened_buffer with T.buffer_decl in unit tests --- .../test_ethosu/test_encode_constants.py | 24 ++--- .../test_ethosu/test_hoist_allocates.py | 31 +++---- .../test_ethosu/test_remove_concatenates.py | 8 +- .../test_ethosu/test_replace_conv2d.py | 71 +++++++------- .../contrib/test_ethosu/test_replace_copy.py | 12 +-- .../contrib/test_ethosu/test_scheduler.py | 6 +- .../test_hexagon/test_2d_physical_buffers.py | 2 +- .../unittest/test_auto_scheduler_feature.py | 12 +-- tests/python/unittest/test_lower_build.py | 24 ++--- .../test_tir_transform_flatten_buffer.py | 92 +++++++------------ .../test_tir_transform_loop_partition.py | 35 ++++--- ...tir_transform_renormalize_split_pattern.py | 24 ++--- 12 files changed, 156 insertions(+), 185 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 92e6cd3e19cb..f26cbf6871a8 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -34,7 +34,7 @@ @tvm.script.ir_module class WeightStreamOnlyU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([128], "uint8") @@ -45,8 +45,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_5 = T.buffer_decl([32], "uint8") buffer_6 = T.buffer_decl([112], "uint8") buffer_7 = T.buffer_decl([32], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], "int8", data=placeholder.data) + ethosu_write = T.buffer_decl([2048], "int8", data=ethosu_write.data) # body p1_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) p2_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) @@ -169,13 +169,13 @@ def _get_func(): @tvm.script.ir_module class RereadWeightsU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([304], "uint8") buffer_1 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], "int8", data=placeholder.data) + ethosu_write = 
T.buffer_decl([2048], "int8", data=ethosu_write.data) # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) @@ -275,15 +275,15 @@ def _get_func(): @tvm.script.ir_module class DirectReadOnlyU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([592], "uint8") buffer_1 = T.buffer_decl([160], "uint8") buffer_2 = T.buffer_decl([160], "uint8") buffer_3 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], "int8", data=placeholder.data) + ethosu_write = T.buffer_decl([2048], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -371,7 +371,7 @@ def _get_func(): @tvm.script.ir_module class MixedReadU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([592], "uint8") @@ -384,8 +384,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_7 = T.buffer_decl([32], "uint8") buffer_8 = T.buffer_decl([80], "uint8") buffer_9 = T.buffer_decl([32], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) + placeholder = T.buffer_decl([2048], "int8", data=placeholder.data) + ethosu_write = T.buffer_decl([2048], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) diff --git a/tests/python/contrib/test_ethosu/test_hoist_allocates.py b/tests/python/contrib/test_ethosu/test_hoist_allocates.py index b54b92950180..b12024aebf46 100644 --- a/tests/python/contrib/test_ethosu/test_hoist_allocates.py +++ b/tests/python/contrib/test_ethosu/test_hoist_allocates.py @@ -106,15 +106,15 @@ def test_double_convolution(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(3402,), "int8"], placeholder_encoded: T.Buffer[(128,), "uint8"], placeholder_encoded_1: T.Buffer[(32,), "uint8"], placeholder_encoded_2: T.Buffer[(128,), "uint8"], placeholder_encoded_3: T.Buffer[(32,), "uint8"], ethosu_write: 
T.Buffer[(3402,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 27, 42, 3), "int8"], placeholder_encoded: T.Buffer[(3, 3, 2, 3), "uint8"], placeholder_encoded_1: T.Buffer[(3, 10), "uint8"], placeholder_encoded_2: T.Buffer[(3, 3, 2, 3), "uint8"], placeholder_encoded_3: T.Buffer[(3, 10), "uint8"], ethosu_write: T.Buffer[(1, 27, 42, 3), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 27, 42, 3], dtype="int8", data=placeholder.data) - T.preflattened_buffer(placeholder_encoded, [3, 3, 2, 3], dtype="int8") - T.preflattened_buffer(placeholder_encoded_1, [3, 10], dtype="uint8") - T.preflattened_buffer(placeholder_encoded_2, [3, 3, 2, 3], dtype="int8") - T.preflattened_buffer(placeholder_encoded_3, [3, 10], dtype="uint8") - T.preflattened_buffer(ethosu_write, [1, 27, 42, 3], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([3402], dtype="int8", data=placeholder.data) + placeholder_encoded = T.buffer_decl([128], dtype="int8", data=placeholder_encoded.data) + placeholder_encoded_1 = T.buffer_decl([32], dtype="uint8", data=placeholder_encoded_1.data) + placeholder_encoded_2 = T.buffer_decl([128], dtype="int8", data=placeholder_encoded_2.data) + placeholder_encoded_3 = T.buffer_decl([32], dtype="uint8", data=placeholder_encoded_3.data) + ethosu_write = T.buffer_decl([3402], dtype="int8", data=ethosu_write.data) # body placeholder_global = T.allocate([128], "uint8", "global") T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 128, placeholder_global[0], dtype="handle")) @@ -145,11 +145,10 @@ def test_identities(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(24,), "int8"], T_concat: T.Buffer[(24,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 2, 3, 4), "int8"], T_concat: T.Buffer[(24,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 2, 3, 4], dtype="int8", data=placeholder.data) - T.preflattened_buffer(T_concat, [24], dtype="int8", data=T_concat.data) + placeholder = T.buffer_decl([24], dtype="int8", data=placeholder.data) # body ethosu_write = T.allocate([12], "int8", "global") T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 3, 4, 1, 0, 3, placeholder[12], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 3, 4, 1, 0, 3, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -179,11 +178,11 @@ def test_outer_seq_stmt(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), 
"uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=ethosu_write.data) # body with T.allocate([128], "uint8", "global") as placeholder_global: T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, placeholder_global[0], dtype="handle")) @@ -221,11 +220,11 @@ def test_allocate_without_seq_stmt(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=ethosu_write.data) # body placeholder_global = T.allocate([128], "uint8", "global") placeholder_global_1 = T.allocate([112], "uint8", "global") diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index cc996e59412c..b575fd84c6b1 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -30,7 +30,7 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,), "int8"], T_concat: T.Buffer[(4096,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 8, 12, 16), "int8"], placeholder_1: T.Buffer[(1, 8, 10, 16), "int8"], T_concat: T.Buffer[(1, 8, 32, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([2992], "uint8") @@ -41,9 +41,9 @@ def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,) buffer_5 = T.buffer_decl([160], "uint8") buffer_6 = T.buffer_decl([2992], "uint8") buffer_7 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 12, 16], "int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, [1, 8, 10, 16], "int8", 
data=placeholder_1.data) - T.preflattened_buffer(T_concat, [1, 8, 32, 16], "int8", data=T_concat.data) + placeholder = T.buffer_decl([1536], "int8", data=placeholder.data) + placeholder_1 = T.buffer_decl([1280], "int8", data=placeholder_1.data) + T_concat = T.buffer_decl([4096], "int8", data=T_concat.data) # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 63f9fc44c778..f4a3a7b60d07 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -366,15 +366,15 @@ def _visit(stmt): @tvm.script.ir_module class Conv2dDoubleCascade1: @T.prim_func - def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([304], "uint8") buffer_1 = T.buffer_decl([80], "uint8") buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) + placeholder_5 = T.buffer_decl([192], 'int8', data=placeholder_5.data) + ethosu_write_1 = T.buffer_decl([512], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -387,15 +387,15 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, @tvm.script.ir_module class Conv2dDoubleCascade2: @T.prim_func - def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([80], "uint8") buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([1312], "uint8") buffer_3 = T.buffer_decl([2608], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) + placeholder_5 = T.buffer_decl([192], 'int8', data=placeholder_5.data) + ethosu_write_1 = T.buffer_decl([512], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1536], 
"int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -408,15 +408,15 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, @tvm.script.ir_module class Conv2dDoubleCascade3: @T.prim_func - def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640,), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], ethosu_write_1: T.Buffer[(1, 20, 4, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([1744], "uint8") buffer_1 = T.buffer_decl([80], "uint8") buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([880], "uint8") - T.preflattened_buffer(placeholder_5, [1, 16, 16, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 20, 4, 8], 'int8', data=ethosu_write_1.data) + placeholder_5 = T.buffer_decl([768], 'int8', data=placeholder_5.data) + ethosu_write_1 = T.buffer_decl([640], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -431,15 +431,15 @@ def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640, @tvm.script.ir_module class Conv2dDoubleCascade4: @T.prim_func - def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write_1: T.Buffer[(1, 8, 2, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([1456], "uint8") buffer_1 = T.buffer_decl([352], "uint8") buffer_2 = T.buffer_decl([272], "uint8") buffer_3 = T.buffer_decl([11040], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 1, 8, 16], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 2, 8, 16], 'int8', data=ethosu_write_1.data) + placeholder_5 = T.buffer_decl([1024], 'int8', data=placeholder_5.data) + ethosu_write_1 = T.buffer_decl([2048], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -452,15 +452,15 @@ def main(placeholder_5: T.Buffer[(1024,), 
"int8"], ethosu_write_1: T.Buffer[(204 @tvm.script.ir_module class Conv2dDoubleCascade5: @T.prim_func - def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write: T.Buffer[(1, 32, 32, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([304], "uint8") buffer_3 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 8, 3], 'int8', data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 32, 32, 8], 'int8', data=ethosu_write.data) + placeholder = T.buffer_decl([192], 'int8', data=placeholder.data) + ethosu_write = T.buffer_decl([8192], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) @@ -473,15 +473,15 @@ def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), @tvm.script.ir_module class Conv2dDoubleCascade6: @T.prim_func - def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,), "int8"]) -> None: + def main(placeholder: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write: T.Buffer[(1, 32, 2, 32, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([1456], "uint8") buffer_1 = T.buffer_decl([352], "uint8") buffer_2 = T.buffer_decl([11040], "uint8") buffer_3 = T.buffer_decl([272], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 1, 8, 16], 'int8', data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 32, 2, 32, 16], 'int8', data=ethosu_write.data) + placeholder = T.buffer_decl([1024], 'int8', data=placeholder.data) + ethosu_write = T.buffer_decl([32768], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) @@ -636,13 +636,13 @@ def _get_func( @tvm.script.ir_module class Conv2dInlineCopy1: @T.prim_func - def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024,), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([848], "uint8") buffer_1 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder_3, [1, 10, 12, 8], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 
8, 8, 16], 'int8', data=ethosu_write_1.data) + placeholder_3 = T.buffer_decl([960], 'int8', data=placeholder_3.data) + ethosu_write_1 = T.buffer_decl([1024], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -651,13 +651,13 @@ def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024 @tvm.script.ir_module class Conv2dInlineCopy2: @T.prim_func - def main(placeholder_3: T.Buffer[(315,), "int8"], ethosu_write_1: T.Buffer[(240,), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(1, 7, 9, 5), "int8"], ethosu_write_1: T.Buffer[(1, 3, 5, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([656], "uint8") - T.preflattened_buffer(placeholder_3, [1, 7, 9, 5], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 3, 5, 16], 'int8', data=ethosu_write_1.data) + placeholder_3 = T.buffer_decl([315], 'int8', data=placeholder_3.data) + ethosu_write_1 = T.buffer_decl([240], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -695,13 +695,13 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): @tvm.script.ir_module class Conv2dInlineReshape1: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [4, 6, 8, 1], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) + placeholder_3 = T.buffer_decl([192], 'int8', data=placeholder_3.data) + ethosu_write_1 = T.buffer_decl([768], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), 
T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -711,13 +711,13 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, @tvm.script.ir_module class Conv2dInlineReshape2: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(1, 24, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [1, 24, 8], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) + placeholder_3 = T.buffer_decl([192], 'int8', data=placeholder_3.data) + ethosu_write_1 = T.buffer_decl([768], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -727,13 +727,13 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, @tvm.script.ir_module class Conv2dInlineReshape3: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [192, 1], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) + placeholder_3 = T.buffer_decl([192], 'int8', data=placeholder_3.data) + ethosu_write_1 = T.buffer_decl([768], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -743,13 +743,12 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], 
ethosu_write_1: T.Buffer[(768, @tvm.script.ir_module class Conv2dInlineReshape4: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [192], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) + ethosu_write_1 = T.buffer_decl([768], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 4f06695b25b1..6ba30a940728 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -31,13 +31,13 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([80], "uint8") buffer_1 = T.buffer_decl([304], "uint8") - T.preflattened_buffer(placeholder_3, [1, 16, 16, 32], dtype="int8", data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 8], dtype="int8", data=ethosu_write_1.data) + placeholder_3 = T.buffer_decl([8192], dtype="int8", data=placeholder_3.data) + ethosu_write_1 = T.buffer_decl([2048], dtype="int8", data=ethosu_write_1.data) # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) @@ -77,15 +77,15 @@ def _get_func(): @tvm.script.ir_module class WeightStream: @T.prim_func - def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([416], "uint8") buffer_1 = T.buffer_decl([112], "uint8") buffer_2 = T.buffer_decl([272], "uint8") buffer_3 = T.buffer_decl([64], "uint8") - T.preflattened_buffer(placeholder_5, [1, 16, 16, 32], 
dtype="int8", data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 16], dtype="int8", data=ethosu_write_1.data) + placeholder_5 = T.buffer_decl([8192], dtype="int8", data=placeholder_5.data) + ethosu_write_1 = T.buffer_decl([4096], dtype="int8", data=ethosu_write_1.data) # body placeholder_global_unrolled_iter_0 = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_global_unrolled_iter_1 = T.buffer_decl([272], "uint8", data=placeholder_global_unrolled_iter_0.data) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 8a83e769141d..0d968b879686 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -180,10 +180,10 @@ def test_schedule_cache_reads(): @tvm.script.ir_module class DiamondGraphTir: @T.prim_func - def main(input_buffer: T.Buffer[(301056,), "int8"], output_buffer: T.Buffer[(75264,), "int8"]) -> None: + def main(input_buffer: T.Buffer[(1, 56, 56, 96), "int8"], output_buffer: T.Buffer[(1, 56, 56, 24), "int8"]) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(input_buffer, [1, 56, 56, 96], dtype='int8', data=input_buffer.data) - T.preflattened_buffer(output_buffer, [1, 56, 56, 24], dtype='int8', data=output_buffer.data) + input_buffer = T.buffer_decl([301056], dtype='int8', data=input_buffer.data) + output_buffer = T.buffer_decl([75264], dtype='int8', data=output_buffer.data) weight_buffer = T.buffer_decl([2608], "uint8") bias_buffer = T.buffer_decl([240], "uint8") diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py index 9de55996b031..a3689e9d52e3 100644 --- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py +++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py @@ -242,7 +242,7 @@ def uses_unsupported_physical_dimensions( def test_param_shapes(self, ir_module, transformed_input_shape, transformed_output_shape): func = ir_module["main"] primfunc_input_shape, primfunc_output_shape = [ - list(func.preflattened_buffer_map[param].shape) for param in func.params + list(func.buffer_map[param].shape) for param in func.params ] assert primfunc_input_shape == transformed_input_shape assert primfunc_output_shape == transformed_output_shape diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py index 2a058cdbc05c..b8da618044b7 100644 --- a/tests/python/unittest/test_auto_scheduler_feature.py +++ b/tests/python/unittest/test_auto_scheduler_feature.py @@ -203,15 +203,15 @@ def test_gpu_feature(): @T.prim_func def tir_matmul( - A: T.Buffer[(16384,), "float32"], - B: T.Buffer[(16384,), "float32"], - C: T.Buffer[(16384,), "float32"], + A: T.Buffer[(256, 256), "float32"], + B: T.Buffer[(256, 256), "float32"], + C: T.Buffer[(256, 256), "float32"], ) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(A, [128, 128], dtype="float32", data=A.data) - T.preflattened_buffer(B, [128, 128], dtype="float32", data=B.data) - T.preflattened_buffer(C, [128, 128], dtype="float32", data=C.data) + A = T.buffer_decl([16384], dtype="float32", data=A.data) + B = T.buffer_decl([16384], dtype="float32", data=B.data) + C = T.buffer_decl([16384], dtype="float32", data=C.data) # 
body for x, y in T.grid(128, 128): C[x * 128 + y] = T.float32(0) diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index bd820b617c2d..c8aae78a4559 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -54,15 +54,15 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: class LoweredModule: @T.prim_func def main( - A: T.Buffer[(16384,), "float32"], - B: T.Buffer[(16384,), "float32"], - C: T.Buffer[(16384,), "float32"], + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - T.preflattened_buffer(A, [128, 128], data=A.data) - T.preflattened_buffer(B, [128, 128], data=B.data) - T.preflattened_buffer(C, [128, 128], data=C.data) + A = T.buffer_decl([16384], data=A.data) + B = T.buffer_decl([16384], data=B.data) + C = T.buffer_decl([16384], data=C.data) # body for x, y in T.grid(128, 128): C[x * 128 + y] = 0.0 @@ -74,15 +74,15 @@ def main( class LoweredTIRModule: @T.prim_func def main( - A: T.Buffer[(16384,), "float32"], - B: T.Buffer[(16384,), "float32"], - C: T.Buffer[(16384,), "float32"], + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(A, [128, 128], data=A.data) - T.preflattened_buffer(B, [128, 128], data=B.data) - T.preflattened_buffer(C, [128, 128], data=C.data) + A = T.buffer_decl([16384], data=A.data) + B = T.buffer_decl([16384], data=B.data) + C = T.buffer_decl([16384], data=C.data) # body for x, y in T.grid(128, 128): C[x * 128 + y] = 0.0 diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 68b1ad338964..4cadfd19a05c 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -28,9 +28,9 @@ def _check(original, transformed): @T.prim_func -def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") +def compacted_elementwise_func( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +) -> None: for i in range(0, 16): with T.block(): T.reads(A[i, 0:16]) @@ -49,11 +49,11 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: @T.prim_func -def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, 256, "float32") - C = T.match_buffer(c, 256, "float32") - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) +def flattened_elementwise_func( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +) -> None: + A = T.buffer_decl((16, 16), dtype="float32", data=A.data) + C = T.buffer_decl((16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): B_new = T.allocate([16], "float32", "global") for j in T.serial(0, 16): @@ -63,9 +63,7 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: @T.prim_func -def compacted_gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") +def compacted_gpu_func(A: T.Buffer[(16, 16), 
"float32"], C: T.Buffer[(16, 16), "float32"]) -> None: for i0 in T.thread_binding(0, 4, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="threadIdx.x"): for i2 in T.thread_binding(0, 2, thread="vthread"): @@ -86,11 +84,9 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None: @T.prim_func -def flattened_gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, 256, "float32") - C = T.match_buffer(c, 256, "float32") - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) +def flattened_gpu_func(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: + A = T.buffer_decl([256], dtype="float32", data=A.data) + C = T.buffer_decl([256], dtype="float32", data=C.data) i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ -130,10 +126,10 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> @T.prim_func def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, n * m, "float32") - C = T.match_buffer(c, n * m, "float32") - T.preflattened_buffer(A, (n, m), "float32", data=A.data) - T.preflattened_buffer(C, (n, m), "float32", data=C.data) + A = T.match_buffer(a, [n, m], "float32") + C = T.match_buffer(c, [n, m], "float32") + A = T.buffer_decl([n * m], "float32", data=A.data) + C = T.buffer_decl([n * m], "float32", data=C.data) for i in range(0, n): B = T.allocate([m], "float32", "global") @@ -144,10 +140,7 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> @T.prim_func -def compacted_predicate_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - C = T.match_buffer(c, (32), "float32") - +def compacted_predicate_func(A: T.Buffer[(32,), "float32"], C: T.Buffer[(32,), "float32"]) -> None: for i, j in T.grid(5, 7): with T.block() as []: T.reads(A[i * 7 + j]) @@ -157,22 +150,14 @@ def compacted_predicate_func(a: T.handle, c: T.handle) -> None: @T.prim_func -def flattened_predicate_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - C = T.match_buffer(c, (32), "float32") - T.preflattened_buffer(A, (32), "float32", data=A.data) - T.preflattened_buffer(C, (32), "float32", data=C.data) - +def flattened_predicate_func(A: T.Buffer[(32,), "float32"], C: T.Buffer[(32,), "float32"]) -> None: for i, j in T.grid(5, 7): if i * 7 + j < 32: C[i * 7 + j] = A[i * 7 + j] + 1.0 @T.prim_func -def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - C = T.match_buffer(c, (32), "float32") - +def compacted_unit_loop_func(A: T.Buffer[(32,), "float32"], C: T.Buffer[(32,), "float32"]) -> None: for x, y, z in T.grid(4, 1, 8): with T.block() as []: T.reads(A[x * 8 + y * 8 + z]) @@ -181,21 +166,15 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: @T.prim_func -def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - C = T.match_buffer(c, (32), "float32") - T.preflattened_buffer(A, (32), "float32", data=A.data) - T.preflattened_buffer(C, (32), "float32", data=C.data) - +def flattened_unit_loop_func(A: T.Buffer[(32,), "float32"], C: T.Buffer[(32,), "float32"]) -> None: for x, z in T.grid(4, 8): C[x * 8 + z] = A[x * 8 + z] + 1.0 @T.prim_func -def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - D = T.match_buffer(d, (32), "float32") - +def 
compacted_multi_alloc_func( + A: T.Buffer[(32,), "float32"], D: T.Buffer[(32,), "float32"] +) -> None: for i in range(0, 32): with T.block() as []: T.reads(A[i]) @@ -208,12 +187,9 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: @T.prim_func -def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - D = T.match_buffer(d, (32), "float32") - T.preflattened_buffer(A, (32), "float32", data=A.data) - T.preflattened_buffer(D, (32), "float32", data=D.data) - +def flattened_multi_alloc_func( + A: T.Buffer[(32,), "float32"], D: T.Buffer[(32,), "float32"] +) -> None: for i in range(0, 32): B = T.allocate((32,), "float32", "global") C = T.allocate((32,), "float32", "global") @@ -223,9 +199,9 @@ def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: @T.prim_func -def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") +def compacted_strided_buffer_func( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +) -> None: for i0 in range(0, 4): with T.block(): T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16]) @@ -246,11 +222,11 @@ def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: @T.prim_func -def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (256,), "float32") - C = T.match_buffer(c, (256,), "float32") - T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) - T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) +def flattened_strided_buffer_func( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +) -> None: + A = T.buffer_decl([256], dtype="float32", data=A.data) + C = T.buffer_decl([256], dtype="float32", data=C.data) for i0 in T.serial(0, 4): B_new = T.allocate([68], "float32", "global") for i1 in T.serial(0, 4): diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 6cfe96664d89..63f29e90aea6 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -544,9 +544,6 @@ def partitioned_concat( A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"], C: T.Buffer[(32,), "float32"] ) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(A, [16], data=A.data) - T.preflattened_buffer(B, [16], data=B.data) - T.preflattened_buffer(C, [32], data=C.data) for i in T.serial(0, 16): C[i] = A[i] for i in T.serial(0, 16): @@ -570,15 +567,15 @@ def test_explicit_partition_hint(): @T.prim_func def partitioned_concat_3( - placeholder: T.Buffer[(50176,), "int8"], - placeholder_1: T.Buffer[(25088,), "int8"], - placeholder_2: T.Buffer[(25088,), "int8"], - T_concat: T.Buffer[(100352,), "int8"], + placeholder: T.Buffer[(1, 64, 28, 28), "int8"], + placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"], + placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], + T_concat: T.Buffer[(1, 128, 28, 28), "int8"], ) -> None: - T.preflattened_buffer(placeholder, [1, 64, 28, 28], "int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, [1, 32, 28, 28], "int8", data=placeholder_1.data) - T.preflattened_buffer(placeholder_2, [1, 32, 28, 28], "int8", data=placeholder_2.data) - T.preflattened_buffer(T_concat, [1, 128, 28, 28], "int8", data=T_concat.data) + placeholder = T.buffer_decl([50176], "int8", 
data=placeholder.data) + placeholder_1 = T.buffer_decl([25088], "int8", data=placeholder_1.data) + placeholder_2 = T.buffer_decl([25088], "int8", data=placeholder_2.data) + T_concat = T.buffer_decl([100352], "int8", data=T_concat.data) for i1, i2, i3 in T.grid(64, 28, 28): T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): @@ -589,15 +586,15 @@ def partitioned_concat_3( @T.prim_func def concat_func_3( - placeholder: T.Buffer[(50176,), "int8"], - placeholder_1: T.Buffer[(25088,), "int8"], - placeholder_2: T.Buffer[(25088,), "int8"], - T_concat: T.Buffer[(100352,), "int8"], + placeholder: T.Buffer[(1, 64, 28, 28), "int8"], + placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"], + placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], + T_concat: T.Buffer[(1, 128, 28, 28), "int8"], ) -> None: - T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data) - T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data) - T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data) + placeholder = T.buffer_decl([50176], "int8", data=placeholder.data) + placeholder_1 = T.buffer_decl([25088], "int8", data=placeholder_1.data) + placeholder_2 = T.buffer_decl([25088], "int8", data=placeholder_2.data) + T_concat = T.buffer_decl([100352], "int8", data=T_concat.data) for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): for i2, i3 in T.grid(28, 28): if 96 <= i1: diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py index fb1fb72eb82c..19bff28af672 100644 --- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -24,12 +24,12 @@ @tvm.script.ir_module class Before: @T.prim_func - def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) - T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) - T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs = T.buffer_decl([8192], dtype="float32", data=inputs.data) + weight = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -55,12 +55,12 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo @tvm.script.ir_module class After: @T.prim_func - def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: # function attr dict 
T.func_attr({"global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) - T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) - T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs = T.buffer_decl([8192], dtype="float32", data=inputs.data) + weight = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -86,15 +86,15 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo @tvm.script.ir_module class After_simplified: @T.prim_func - def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") - T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) - T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) - T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs = T.buffer_decl([8192], dtype="float32", data=inputs.data) + weight = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # body T.launch_thread(blockIdx_x, 64) conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") From 7bdbf5d92867a455520336f40904f1d1f1795076 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 15 Apr 2022 09:21:08 -0500 Subject: [PATCH 04/25] Remove preflattened_buffer from TVMScript stubs --- python/tvm/script/tir/__init__.pyi | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index e4513feb4323..ca8a5f20bb72 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -131,18 +131,6 @@ def store( ) -> None: ... def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ... def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ... -def preflattened_buffer( - buf: Buffer, - shape: Sequence[PrimExpr], - dtype: str = "float32", - data: Optional[Ptr] = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", -) -> Buffer: ... 
""" Intrinsics - tvm builtin From 06b0260d02fe198929250be0c267f7848f134a47 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Apr 2022 08:33:21 -0500 Subject: [PATCH 05/25] Removed additional preflattened usages after rebase --- src/printer/tir_text_printer.cc | 10 ---------- src/tir/transforms/legalize_packed_calls.cc | 4 ++-- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fe829016b6b5..d8e5c45d3194 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -152,16 +152,6 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); } - if (op->preflattened_buffer_map.size() != 0) { - // print preflattened_buffer_map - std::vector preflattened_buffer_map_doc; - for (auto& v : op->preflattened_buffer_map) { - preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second)); - } - doc << Doc::Indent(2, Doc::NewLine() - << "preflattened_buffer_map = {" - << PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}"); - } doc << PrintBody(op->body); return doc; } diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index 344e6c7ae3cb..fed76876f6bf 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -74,9 +74,9 @@ class PackedCallLegalizer : public StmtExprMutator { tvm::runtime::Map::iterator param_buf_it; if (prim_func != nullptr) { auto param_var = prim_func->params[i - 1]; - param_buf_it = prim_func->preflattened_buffer_map.find(param_var); + param_buf_it = prim_func->buffer_map.find(param_var); } - if (prim_func != nullptr && param_buf_it != prim_func->preflattened_buffer_map.end()) { + if (prim_func != nullptr && param_buf_it != prim_func->buffer_map.end()) { Buffer param = (*param_buf_it).second; PrimExpr shape = tvm::tir::Call( DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape); From ae1627402341e47462aa230d06f4fba04871f5f9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Apr 2022 08:33:53 -0500 Subject: [PATCH 06/25] Updated tir::PrimFunc usage in cmsisnn contrib --- src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 210175817f9c..b322540e7bec 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -108,7 +108,7 @@ class RelayToTIRVisitor : public MixedModeMutator { } tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map, - Map(), DictAttrs(dict_attrs)); + DictAttrs(dict_attrs)); ir_module_->Add(global_var, replacement_func); } From 5f47164c34ac345a222d60f8a8118eb9e44c988a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Apr 2022 08:45:08 -0500 Subject: [PATCH 07/25] Removing more usage of preflattened from python files --- .../unittest/test_aot_legalize_packed_call.py | 26 +++---- ...orm_convert_pool_allocations_to_offsets.py | 69 ------------------- .../unittest/test_tvmscript_error_report.py | 22 ------ .../unittest/test_tvmscript_syntax_sugar.py | 17 ----- 4 files changed, 10 insertions(+), 124 deletions(-) diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py 
index c7c0daa30e2f..756f3724bd99 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -26,15 +26,12 @@ class Module: @T.prim_func def tvm_test_cpacked( - A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + A: T.Buffer[(1,), "float32"], + B: T.Buffer[(1,), "float32"], + C: T.Buffer[(1,), "float32"], + device_context: T.Buffer[(1,), "float32"], ) -> T.handle: - A_0 = T.match_buffer(A, (1,), dtype="float32") - A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") - B_0 = T.match_buffer(B, (1,), dtype="float32") - B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32") - C_0 = T.match_buffer(C, (1,), dtype="float32") - C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32") - T.evaluate(C) + T.evaluate(C.data) @T.prim_func def tir_packed_call() -> None: @@ -59,15 +56,12 @@ def tir_packed_call() -> None: class Expected: @T.prim_func def tvm_test_cpacked( - A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + A: T.Buffer[(1,), "float32"], + B: T.Buffer[(1,), "float32"], + C: T.Buffer[(1,), "float32"], + device_context: T.handle, ) -> T.handle: - A_0 = T.match_buffer(A, (1,), dtype="float32") - A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") - B_0 = T.match_buffer(B, (1,), dtype="float32") - B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32") - C_0 = T.match_buffer(C, (1,), dtype="float32") - C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32") - T.evaluate(C) + T.evaluate(C.data) @T.prim_func def tir_packed_call() -> None: diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index ce8675f575ee..98eb24e85f74 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -74,11 +74,8 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_4, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_5, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(T_subtract_1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -89,13 +86,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_65, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_66, 
[9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_67, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(T_cast_21, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -115,9 +108,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_29, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(T_cast_7, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): @@ -164,13 +155,9 @@ def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") - T.preflattened_buffer(placeholder_29, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") - T.preflattened_buffer(T_cast_7, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_6_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(slow_memory_7_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.buffer_decl([200704], dtype="uint8") with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): @@ -185,15 +172,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") - T.preflattened_buffer(placeholder_4, [150528], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") - T.preflattened_buffer(placeholder_5, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") - T.preflattened_buffer(T_subtract_1, [452], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_2_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - 
T.preflattened_buffer(slow_memory_3_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @@ -201,17 +183,11 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None: placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") - T.preflattened_buffer(placeholder_65, [150528], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") - T.preflattened_buffer(placeholder_66, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") - T.preflattened_buffer(placeholder_67, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") - T.preflattened_buffer(T_cast_21, [289], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_4_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(slow_memory_5_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.buffer_decl([157323], "int16") with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): @@ -275,11 +251,8 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") - T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - T.preflattened_buffer(T_cast_1, [215], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -289,13 +262,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") - T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") - T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") - T.preflattened_buffer(placeholder_15, [64], dtype="int32") 
T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - T.preflattened_buffer(T_cast_5, [215], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): @@ -314,13 +283,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") - T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") - T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") - T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - T.preflattened_buffer(T_add_1, [407], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): @@ -340,15 +305,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") - T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") - T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") - T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") - T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - T.preflattened_buffer(T_cast_7, [407], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): @@ -385,13 +345,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") - T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") - T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") - T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - T.preflattened_buffer(T_cast_3, [215], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): @@ -413,13 +369,9 @@ class ResnetStructurePlanned: @T.prim_func def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") - T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - 
T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - T.preflattened_buffer(T_cast_1, [215], dtype="int16") global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_1_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -427,17 +379,11 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") - T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") - T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") - T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") - T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - T.preflattened_buffer(T_cast_7, [407], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_5_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.buffer_decl([360000], 'int16') with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): @@ -457,15 +403,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None: placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") - T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") - T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") - T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - T.preflattened_buffer(T_add_1, [407], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_4_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.buffer_decl([360000], "int16") with 
T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): @@ -485,15 +426,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None: placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") - T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") - T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") - T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - T.preflattened_buffer(T_cast_3, [215], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_2_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): @@ -512,15 +448,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None: placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") - T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") - T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") - T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - T.preflattened_buffer(T_cast_5, [215], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_3_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_1_let = T.buffer_decl([379456], "int16") with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 070b5e85f174..e7b1ad201043 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -613,28 +613,6 @@ def test_non_integer_typed_block_iter(): check_error(non_integer_typed_block_iter, 3) -def preflattened_buffer_map_align_nonint(foo: T.handle): - foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer( - foo_1, [1], align="bar" - ) # check_error: align: want int or IntImm, got 'bar' - - -def test_preflattened_buffer_map_align(): - check_error(preflattened_buffer_map_align_nonint, 3) - - -def preflattened_buffer_map_offset_factor_nonint(foo: T.handle): - foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer( - foo_1, [1], offset_factor="bar" - ) # check_error: offset_factor: want int or IntImm, got 'bar' - - 
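As a usage note, Python code that previously looked up per-parameter shapes in `func.preflattened_buffer_map` now reads the same logical shapes from `func.buffer_map`. A hypothetical snippet, assuming `mod` is an already-lowered IRModule whose "main" is a PrimFunc (mirroring the updated Hexagon test earlier in this series):

func = mod["main"]
# Logical (pre-flattened) shapes, still recorded for argument validation.
param_shapes = [list(func.buffer_map[param].shape) for param in func.params]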
-def test_preflattened_buffer_map_offset_factor(): - check_error(preflattened_buffer_map_offset_factor_nonint, 3) - - def strided_buffer_region(A: T.handle): # do not allow stride in buffer region A = T.match_buffer((128, 128), "int32") diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index a0964ea4d77c..1b1fcb06e1fb 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -181,23 +181,6 @@ def test_dynamic_shape_gemm(): assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) -@T.prim_func -def preflattened_buffer_map(A: T.handle, B: T.handle): - A_1 = T.match_buffer(A, [1]) - T.preflattened_buffer(A_1, [1], align=T.int32(1), offset_factor=T.int64(2)) - B_1 = T.match_buffer(B, [1]) - T.preflattened_buffer(B_1, [1]) - B_1[0] = A_1[0] - - -def test_preflattened_buffer_map(): - A_var = [ - k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A" - ][0] - assert preflattened_buffer_map.preflattened_buffer_map[A_var].data_alignment == 1 - assert preflattened_buffer_map.preflattened_buffer_map[A_var].offset_factor == 2 - - @T.prim_func def match_buffer_int64(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32") From 83028e18470cefc5a7aeea4b2488e3664af343de Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 21 Apr 2022 10:30:32 -0500 Subject: [PATCH 08/25] Removing duplicate buffer names --- .../unittest/test_auto_scheduler_feature.py | 16 ++++----- tests/python/unittest/test_lower_build.py | 24 +++++++------ .../test_tir_transform_flatten_buffer.py | 32 ++++++++--------- .../test_tir_transform_loop_partition.py | 32 +++++++++-------- ...tir_transform_renormalize_split_pattern.py | 36 +++++++++---------- 5 files changed, 74 insertions(+), 66 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py index b8da618044b7..a366ca761001 100644 --- a/tests/python/unittest/test_auto_scheduler_feature.py +++ b/tests/python/unittest/test_auto_scheduler_feature.py @@ -67,7 +67,7 @@ def test_cpu_matmul(): """ lowered IR: - + Placeholder: A, B parallel i.0 (0,32) parallel j.0 (0,64) @@ -78,8 +78,8 @@ def test_cpu_matmul(): """ # check touched memory in bytes, touched unique memory in bytes, reuse distance, etc. 
- assert fequal(fea_dict[c_name + ".bytes"], math.log2(512**3 * 4 + 1)) - assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512**2 * 4 + 1)) + assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1)) + assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1)) assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1)) assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1)) assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1)) @@ -209,14 +209,14 @@ def tir_matmul( ) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - A = T.buffer_decl([16384], dtype="float32", data=A.data) - B = T.buffer_decl([16384], dtype="float32", data=B.data) - C = T.buffer_decl([16384], dtype="float32", data=C.data) + A_flat = T.buffer_decl([16384], dtype="float32", data=A.data) + B_flat = T.buffer_decl([16384], dtype="float32", data=B.data) + C_flat = T.buffer_decl([16384], dtype="float32", data=C.data) # body for x, y in T.grid(128, 128): - C[x * 128 + y] = T.float32(0) + C_flat[x * 128 + y] = T.float32(0) for k in T.serial(128): - C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] + C_flat[x * 128 + y] = C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k] def test_primfunc_without_lowering(): diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index c8aae78a4559..665697b84be9 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -60,14 +60,16 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A = T.buffer_decl([16384], data=A.data) - B = T.buffer_decl([16384], data=B.data) - C = T.buffer_decl([16384], data=C.data) + A_flat = T.buffer_decl([16384], data=A.data) + B_flat = T.buffer_decl([16384], data=B.data) + C_flat = T.buffer_decl([16384], data=C.data) # body for x, y in T.grid(128, 128): - C[x * 128 + y] = 0.0 + C_flat[x * 128 + y] = 0.0 for k in T.serial(0, 128): - C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] + C_flat[x * 128 + y] = ( + C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k] + ) @tvm.script.ir_module @@ -80,14 +82,16 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.buffer_decl([16384], data=A.data) - B = T.buffer_decl([16384], data=B.data) - C = T.buffer_decl([16384], data=C.data) + A_flat = T.buffer_decl([16384], data=A.data) + B_flat = T.buffer_decl([16384], data=B.data) + C_flat = T.buffer_decl([16384], data=C.data) # body for x, y in T.grid(128, 128): - C[x * 128 + y] = 0.0 + C_flat[x * 128 + y] = 0.0 for k in T.serial(0, 128): - C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] + C_flat[x * 128 + y] = ( + C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k] + ) def test_lower_build_te_schedule(): diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 4cadfd19a05c..b7f18cede8c5 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -52,14 +52,14 @@ def compacted_elementwise_func( def flattened_elementwise_func( A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] ) -> None: - A = T.buffer_decl((16, 16), 
dtype="float32", data=A.data) - C = T.buffer_decl((16, 16), dtype="float32", data=C.data) + A_flat = T.buffer_decl((256,), dtype="float32", data=A.data) + C_flat = T.buffer_decl((256,), dtype="float32", data=C.data) for i in T.serial(0, 16): B_new = T.allocate([16], "float32", "global") for j in T.serial(0, 16): - B_new[j] = A[((i * 16) + j)] + 1.0 + B_new[j] = A_flat[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): - C[((i * 16) + j)] = B_new[j] * 2.0 + C_flat[((i * 16) + j)] = B_new[j] * 2.0 @T.prim_func @@ -85,8 +85,8 @@ def compacted_gpu_func(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), " @T.prim_func def flattened_gpu_func(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: - A = T.buffer_decl([256], dtype="float32", data=A.data) - C = T.buffer_decl([256], dtype="float32", data=C.data) + A_flat = T.buffer_decl([256], dtype="float32", data=A.data) + C_flat = T.buffer_decl([256], dtype="float32", data=C.data) i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ -97,9 +97,9 @@ def flattened_gpu_func(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), " T.launch_thread(i2, 2) B = T.allocate([16], "float32", "local") for j in range(0, 16): - B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 + B[j] = A_flat[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): - C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 + C_flat[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 @T.prim_func @@ -128,15 +128,15 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: A = T.match_buffer(a, [n, m], "float32") C = T.match_buffer(c, [n, m], "float32") - A = T.buffer_decl([n * m], "float32", data=A.data) - C = T.buffer_decl([n * m], "float32", data=C.data) + A_flat = T.buffer_decl([n * m], "float32", data=A.data) + C_flat = T.buffer_decl([n * m], "float32", data=C.data) for i in range(0, n): B = T.allocate([m], "float32", "global") for j in range(0, m): - B[j] = A[i * m + j] + 1.0 + B[j] = A_flat[i * m + j] + 1.0 for j in range(0, m): - C[i * m + j] = B[j] * 2.0 + C_flat[i * m + j] = B[j] * 2.0 @T.prim_func @@ -225,16 +225,16 @@ def compacted_strided_buffer_func( def flattened_strided_buffer_func( A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] ) -> None: - A = T.buffer_decl([256], dtype="float32", data=A.data) - C = T.buffer_decl([256], dtype="float32", data=C.data) + A_flat = T.buffer_decl([256], dtype="float32", data=A.data) + C_flat = T.buffer_decl([256], dtype="float32", data=C.data) for i0 in T.serial(0, 4): B_new = T.allocate([68], "float32", "global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): - B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 + B_new[i1 * 17 + j] = A_flat[i0 * 64 + i1 * 16 + j] + 1.0 for i1 in T.serial(0, 4): for j in T.serial(0, 16): - C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 + C_flat[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 63f29e90aea6..524a879666fa 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -572,16 +572,16 @@ def partitioned_concat_3( placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], T_concat: T.Buffer[(1, 128, 28, 28), "int8"], ) -> None: - placeholder = T.buffer_decl([50176], "int8", data=placeholder.data) - 
placeholder_1 = T.buffer_decl([25088], "int8", data=placeholder_1.data) - placeholder_2 = T.buffer_decl([25088], "int8", data=placeholder_2.data) - T_concat = T.buffer_decl([100352], "int8", data=T_concat.data) + placeholder_flat = T.buffer_decl([50176], "int8", data=placeholder.data) + placeholder_1_flat = T.buffer_decl([25088], "int8", data=placeholder_1.data) + placeholder_2_flat = T.buffer_decl([25088], "int8", data=placeholder_2.data) + T_concat_flat = T.buffer_decl([100352], "int8", data=T_concat.data) for i1, i2, i3 in T.grid(64, 28, 28): - T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] + T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_flat[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): - T_concat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1[i1 * 784 + i2 * 28 + i3] + T_concat_flat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1_flat[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): - T_concat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2[i1 * 784 + i2 * 28 + i3] + T_concat_flat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2_flat[i1 * 784 + i2 * 28 + i3] @T.prim_func @@ -591,18 +591,22 @@ def concat_func_3( placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], T_concat: T.Buffer[(1, 128, 28, 28), "int8"], ) -> None: - placeholder = T.buffer_decl([50176], "int8", data=placeholder.data) - placeholder_1 = T.buffer_decl([25088], "int8", data=placeholder_1.data) - placeholder_2 = T.buffer_decl([25088], "int8", data=placeholder_2.data) - T_concat = T.buffer_decl([100352], "int8", data=T_concat.data) + placeholder_flat = T.buffer_decl([50176], "int8", data=placeholder.data) + placeholder_1_flat = T.buffer_decl([25088], "int8", data=placeholder_1.data) + placeholder_2_flat = T.buffer_decl([25088], "int8", data=placeholder_2.data) + T_concat_flat = T.buffer_decl([100352], "int8", data=T_concat.data) for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): for i2, i3 in T.grid(28, 28): if 96 <= i1: - T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3 - 75264] + T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_2_flat[ + i1 * 784 + i2 * 28 + i3 - 75264 + ] if 64 <= i1 and i1 < 96: - T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[i1 * 784 + i2 * 28 + i3 - 50176] + T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_1_flat[ + i1 * 784 + i2 * 28 + i3 - 50176 + ] if i1 < 64: - T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] + T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_flat[i1 * 784 + i2 * 28 + i3] def test_condition_mutually_exclusive(): diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py index 19bff28af672..fbb0a8ec3395 100644 --- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -27,9 +27,9 @@ class Before: def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - inputs = T.buffer_decl([8192], dtype="float32", data=inputs.data) - weight = T.buffer_decl([2097152], dtype="float32", data=weight.data) - conv2d_transpose_nhwc = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs_flat = T.buffer_decl([8192], dtype="float32", 
data=inputs.data) + weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -43,13 +43,13 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] + conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] @tvm.script.ir_module @@ -58,9 +58,9 @@ class After: def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - inputs = T.buffer_decl([8192], dtype="float32", data=inputs.data) - weight = T.buffer_decl([2097152], dtype="float32", data=weight.data) - 
conv2d_transpose_nhwc = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs_flat = T.buffer_decl([8192], dtype="float32", data=inputs.data) + weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -74,13 +74,13 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] + conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] @tvm.script.ir_module @@ -92,9 +92,9 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") - inputs = T.buffer_decl([8192], dtype="float32", data=inputs.data) - weight = 
T.buffer_decl([2097152], dtype="float32", data=weight.data) - conv2d_transpose_nhwc = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) + inputs_flat = T.buffer_decl([8192], dtype="float32", data=inputs.data) + weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data) + conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data) # body T.launch_thread(blockIdx_x, 64) conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") @@ -105,13 +105,13 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] + conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg # fmt: on From 077327919c08e7d17100f675d9ce60189a5b9d49 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 21 Apr 2022 13:52:29 -0500 Subject: [PATCH 09/25] Corrected linting errors --- tests/python/unittest/test_auto_scheduler_feature.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_feature.py 
b/tests/python/unittest/test_auto_scheduler_feature.py index a366ca761001..d9ef621db427 100644 --- a/tests/python/unittest/test_auto_scheduler_feature.py +++ b/tests/python/unittest/test_auto_scheduler_feature.py @@ -78,8 +78,8 @@ def test_cpu_matmul(): """ # check touched memory in bytes, touched unique memory in bytes, reuse distance, etc. - assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1)) - assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1)) + assert fequal(fea_dict[c_name + ".bytes"], math.log2(512**3 * 4 + 1)) + assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512**2 * 4 + 1)) assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1)) assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1)) assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1)) From 877cc24213ad75cd6879fe70813da36eb99a7290 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 25 Apr 2022 09:49:19 -0500 Subject: [PATCH 10/25] Updated BufferAllocationLocator to ignore aliases of arg buffers --- .../transforms/plan_update_buffer_allocation_location.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 81dfceb40d32..e9869acdf156 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -35,17 +35,18 @@ class BufferAllocationLocator : public StmtExprMutator { public: explicit BufferAllocationLocator(const PrimFunc& func) { Map> buffer_lca = DetectBufferAccessLCA(func); - std::unordered_set arg_buffers; + + std::unordered_set arg_buffer_vars; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; - arg_buffers.emplace(buffer.get()); + arg_buffer_vars.emplace(buffer->data.get()); buffer_data_to_buffer_.Set(buffer->data, buffer); } // create buffers to be allocated at each stmts for (const auto& kv : buffer_lca) { const Buffer& buffer = kv.first; const StmtNode* stmt = kv.second.get(); - if (arg_buffers.count(buffer.get())) { + if (arg_buffer_vars.count(buffer->data.get())) { continue; } alloc_buffers_[stmt].push_back(buffer); From 2fde1a923025dd9cd4f2158c8ab7650d13e47f27 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 18 May 2022 12:26:47 -0500 Subject: [PATCH 11/25] Replaced more preflatten occurrences --- .../test_ethosu/test_encode_constants.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index f26cbf6871a8..c2514331af62 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -70,7 +70,7 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ @tvm.script.ir_module class WeightStreamOnlyU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition @@ -82,8 +82,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), 
buffer_encoded_5 = T.buffer_decl([32], dtype="uint8") buffer_encoded_6 = T.buffer_decl([160], dtype="uint8") buffer_encoded_7 = T.buffer_decl([32], dtype="uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body placeholder_global = T.allocate([176], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_global_1 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) @@ -191,14 +191,14 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ @tvm.script.ir_module class RereadWeightsU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition placeholder_encoded = T.buffer_decl([368], dtype="uint8") placeholder_encoded_1 = T.buffer_decl([96], dtype="uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body placeholder_global = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_global_1 = T.buffer_decl([368], dtype="uint8", data=placeholder_global.data) @@ -294,7 +294,7 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ @tvm.script.ir_module class DirectReadOnlyU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition @@ -302,8 +302,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8") placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8") placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 
0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -409,7 +409,7 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ @tvm.script.ir_module class MixedReadU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition @@ -423,8 +423,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_encoded_7 = T.buffer_decl([32], dtype="uint8") placeholder_encoded = T.buffer_decl([608], dtype="uint8") placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) From c95b1864baa44af6d20d89064f3d417aa4605d7c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 18 May 2022 12:32:28 -0500 Subject: [PATCH 12/25] Removed preflatten usage from merge --- src/tir/usmp/transform/create_io_allocates.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tir/usmp/transform/create_io_allocates.cc b/src/tir/usmp/transform/create_io_allocates.cc index 59eee961632d..cf754131776c 100644 --- a/src/tir/usmp/transform/create_io_allocates.cc +++ b/src/tir/usmp/transform/create_io_allocates.cc @@ -195,9 +195,8 @@ IRModule IOAllocateCreator::operator()() { } } const GlobalVar& gv = mod_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main); - mod_->Update(gv, - PrimFunc(new_main_params, main_body, main_func_->ret_type, main_func_->buffer_map, - main_func_->preflattened_buffer_map, main_func_->attrs, main_func_->span)); + mod_->Update(gv, PrimFunc(new_main_params, main_body, main_func_->ret_type, + main_func_->buffer_map, main_func_->attrs, main_func_->span)); return mod_; } From df85734d8d35cbf5eeb90e7ac45028d1c7ea42dc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 18 May 2022 14:01:57 -0500 Subject: [PATCH 13/25] T.handle -> T.Buffer in PrimFunc args for AOT test This was incorrectly done in the first pass through, caught by CI. 
--- tests/python/unittest/test_aot_legalize_packed_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index 756f3724bd99..c0b41e3b951a 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -59,7 +59,7 @@ def tvm_test_cpacked( A: T.Buffer[(1,), "float32"], B: T.Buffer[(1,), "float32"], C: T.Buffer[(1,), "float32"], - device_context: T.handle, + device_context: T.Buffer[(1,), "float32"], ) -> T.handle: T.evaluate(C.data) From 65f3ebd5e371abba25a655068efab2ca355c7a6a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 27 Sep 2022 19:25:43 -0500 Subject: [PATCH 14/25] More removal of preflattened instances --- include/tvm/script/ir_builder/tir/frame.h | 3 - include/tvm/script/ir_builder/tir/ir.h | 20 ------ python/tvm/script/ir_builder/tir/ir.py | 68 ------------------- src/relay/backend/aot/aot_lower_main.cc | 2 +- src/script/ir_builder/tir/frame.cc | 1 - src/script/ir_builder/tir/ir.cc | 22 ------ src/tir/contrib/ethosu/passes.cc | 1 - .../test_ethosu/test_merge_constants.py | 6 +- .../test_tir_transform_loop_partition.py | 14 ++-- .../unittest/test_tvmscript_ir_builder_tir.py | 4 -- 10 files changed, 12 insertions(+), 129 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index aa2386e7f1e4..c55a27a0fd42 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -75,8 +75,6 @@ class PrimFuncFrameNode : public TIRFrameNode { Optional ret_type; /*! \brief Maps some parameters to specific Buffer data structures. */ Map buffer_map; - /*! \brief The buffer map prior to flattening. */ - Map preflattened_buffer_map; /*! \brief Additional attributes storing the meta-data */ Optional> attrs; /*! \brief The variable map bound to thread env. */ @@ -90,7 +88,6 @@ class PrimFuncFrameNode : public TIRFrameNode { v->Visit("args", &args); v->Visit("ret_type", &ret_type); v->Visit("buffer_map", &buffer_map); - v->Visit("preflattened_buffer_map", &preflattened_buffer_map); v->Visit("attrs", &attrs); v->Visit("env_threads", &env_threads); v->Visit("root_alloc_buffers", &root_alloc_buffers); diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 7460099f9448..59006bbe927b 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -114,26 +114,6 @@ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = Data int align = -1, int offset_factor = 0, String buffer_type = "default", Array axis_separators = {}); -/*! - * \brief The pre-flattened buffer statement. - * \param postflattened_buffer The original buffer to be flattened. - * \param shape The type of the buffer prior to flattening. - * \param dtype The data type in the content of the buffer. - * \param data The pointer to the head of the data. - * \param strides The strides of each dimension. - * \param elem_offset The offset in terms of number of dtype elements (including lanes). - * \param storage_scope The optional storage scope of buffer data pointer. - * \param align The alignment requirement of data pointer in bytes. - * \param offset_factor The factor of elem_offset field. - * \param buffer_type The buffer type. - * \param axis_separators The separators between input axes when generating flattened output axes. 
- */ -void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, - DataType dtype = DataType::Float(32), Optional data = NullOpt, - Array strides = {}, PrimExpr elem_offset = PrimExpr(), - String storage_scope = "global", int align = -1, int offset_factor = 0, - String buffer_type = "default", Array axis_separators = {}); - /*! * \brief The block declaration statement. * \param name The name of the block. diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 4ec1511f2907..ef96cbdee885 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -274,74 +274,6 @@ def match_buffer( ) -def preflattened_buffer( - postflattened: Buffer, - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], - dtype: str = "float32", - data: Var = None, - strides: List[PrimExpr] = None, - elem_offset: PrimExpr = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: List[int] = None, -) -> None: - """The pre-flattened buffer statement. - - Parameters - ---------- - postflattened : Buffer - The original buffer to be flattened. - - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] - The type of the buffer prior to flattening. - - dtype : str - The data type in the content of the buffer. - - data : Var - The pointer to the head of the data. - - strides : List[PrimExpr] - The strides of each dimension. - - elem_offset : PrimExpr - The offset in terms of number of dtype elements (including lanes). - - scope : str - The optional storage scope of buffer data pointer. - - align : int - The alignment requirement of data pointer in bytes. - - offset_factor : int - The factor of elem_offset field. - - buffer_type : str - The buffer type. - - axis_separators : List[int] - The separators between input axes when generating flattened output axes. - """ - shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape - if strides is None: - strides = [] - _ffi_api.PreflattenedBuffer( # type: ignore[attr-defined] # pylint: disable=no-member - postflattened, - shape, - dtype, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - ) - - def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: """The block declaration statement. diff --git a/src/relay/backend/aot/aot_lower_main.cc b/src/relay/backend/aot/aot_lower_main.cc index ce72595dc10b..a174ae2c3e59 100644 --- a/src/relay/backend/aot/aot_lower_main.cc +++ b/src/relay/backend/aot/aot_lower_main.cc @@ -498,7 +498,7 @@ class AOTMainLowerer : public MixedModeVisitor { tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, + return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, DictAttrs(dict_attrs)); } diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index aa9efa653f71..2050097aeb24 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -34,7 +34,6 @@ void PrimFuncFrameNode::ExitWithScope() { /*body=*/AsStmt(stmts), /*ret_type=*/ret_type.value_or(TupleType::Empty()), /*buffer_map=*/buffer_map, - /*preflattened_buffer_map=*/preflattened_buffer_map, /*attrs=*/attrs.defined() ? 
DictAttrs(attrs.value()) : NullValue()); func = tvm::tir::ScriptComplete(func, root_alloc_buffers); IRBuilder builder = IRBuilder::Current(); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 6be6e2619fea..a63f9a0455f1 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -58,7 +58,6 @@ PrimFuncFrame PrimFunc() { n->args.clear(); n->ret_type = NullOpt; n->buffer_map.clear(); - n->preflattened_buffer_map.clear(); n->attrs = NullOpt; n->env_threads.clear(); n->root_alloc_buffers.clear(); @@ -137,26 +136,6 @@ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optio return buffer; } -void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, DataType dtype, - Optional data, Array strides, PrimExpr elem_offset, - String storage_scope, int align, int offset_factor, String buffer_type_str, - Array axis_separators) { - PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer"); - for (auto const& p : frame->buffer_map) { - if (p.second.same_as(postflattened_buffer)) { - String buffer_name(postflattened_buffer->name + "_preflatten"); - Buffer buffer = - BufferDecl(shape, dtype, buffer_name, data.value_or(p.second->data), strides, elem_offset, - storage_scope, align, offset_factor, buffer_type_str, axis_separators); - details::Namer::Name(buffer, buffer_name); - frame->preflattened_buffer_map.Set(p.first, buffer); - return; - } - } - LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name - << " does not exist."; -} - BlockFrame Block(String name, bool no_realize) { ObjectPtr n = make_object(); n->name = name; @@ -595,7 +574,6 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init); diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index f34a314fe7eb..ba4054ac8ce5 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -522,7 +522,6 @@ class MergeConstantsMutator : public StmtExprMutator { prim_func_node->body = std::move(new_body); prim_func_node->buffer_map = std::move(new_buffer_map); prim_func_node->params = std::move(new_params); - prim_func_node->preflattened_buffer_map = {}; PrimFunc f{GetRef(prim_func_node)}; // Add the new const dict as an attribute diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index 337b5c70d125..76d208f89bbd 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -399,12 +399,12 @@ def test_read_from_the_same_buffer(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[[1, 16, 16, 32], "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], input_ethosu_write: T.Buffer[[1, 16, 16, 8], 
"int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([368], "uint8", "global") p1 = T.buffer_decl([368], "uint8", data=p1_data) diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 1a9656583c88..6a111a4bd77c 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -622,9 +622,9 @@ def test_condition_mutually_exclusive(): def test_loop_partition_unroll_hint(): @T.prim_func - def main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None: - T.preflattened_buffer(A, [1, 3, 224, 224], "int8", data=A.data) - T.preflattened_buffer(B, [1, 224, 7, 16], "int8", data=B.data) + def main(A: T.Buffer[[1, 3, 224, 224], "int8"], B: T.Buffer[[1, 224, 7, 16], "int8"]) -> None: + A = T.buffer_decl(150528, "int8", data=A.data) + B = T.buffer_decl(25088, "int8", data=B.data) for ax0 in T.serial( 112, annotations={"pragma_loop_partition_hint": True}, @@ -634,9 +634,11 @@ def main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None: B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax0 * 2 + ax2 - 3] @T.prim_func - def partitioned_main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None: - T.preflattened_buffer(A, [1, 3, 224, 224], dtype="int8", data=A.data) - T.preflattened_buffer(B, [1, 224, 7, 16], dtype="int8", data=B.data) + def partitioned_main( + A: T.Buffer[[1, 3, 224, 224], "int8"], B: T.Buffer[[1, 224, 7, 16], "int8"] + ) -> None: + A = T.buffer_decl(150528, dtype="int8", data=A.data) + B = T.buffer_decl(25088, dtype="int8", data=B.data) # body for ax1, ax2, ax3 in T.grid(224, 7, 16): if 3 <= ax2 and ax3 < 3: diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index dbc9b594fb87..fdf811af6522 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -60,7 +60,6 @@ def test_ir_builder_tir_primfunc_complete(): T.func_attr({"key": "value"}) T.func_ret(tvm.ir.PrimType("int64")) buffer_d = T.match_buffer(d, (64, 64), "int64") - T.preflattened_buffer(e, (32, 32), "int8", data=e.data) T.evaluate(0) # the prim_func generated by IRBuilder @@ -83,9 +82,6 @@ def test_ir_builder_tir_primfunc_complete(): body=tir.Evaluate(0), ret_type=tvm.ir.PrimType("int64"), buffer_map={c_handle: c_buffer, d_handle: d_buffer, e_handle: e_buffer}, - preflattened_buffer_map={ - e_handle: tir.decl_buffer((32, 32), "int8", name="e_preflatten", data=e_buffer.data) - }, attrs=tvm.ir.make_node("DictAttrs", key="value"), ) From fdbc64dd2b3a0d711d800bfcc9deff2d644c78dd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 28 Sep 2022 11:40:16 -0500 Subject: [PATCH 15/25] lint fixes --- python/tvm/script/ir_builder/tir/ir.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ef96cbdee885..5b9a2549ae72 100644 --- 
a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1736,7 +1736,6 @@ def f(): "func_attr", "func_ret", "match_buffer", - "preflattened_buffer", "block", "init", "where", From bc99a5501fda21768215fbc46dc1c692c631d719 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Nov 2022 10:16:13 -0600 Subject: [PATCH 16/25] Update following merge * Remove additional preflatten usage * Use tuple instead of list for `T.Buffer`, as lists now throw * A couple numeric updates --- .../test_ethosu/test_merge_constants.py | 2 +- .../test_tir_transform_flatten_buffer.py | 4 ++-- .../test_tir_transform_loop_partition.py | 20 ++++++++----------- .../test_tir_transform_thread_sync.py | 4 ++-- ...orm_convert_pool_allocations_to_offsets.py | 3 --- .../unittest/test_tvmscript_ir_builder_tir.py | 1 - 6 files changed, 13 insertions(+), 21 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index 76d208f89bbd..9be404658663 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -399,7 +399,7 @@ def test_read_from_the_same_buffer(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(input_placeholder: T.Buffer[[1, 16, 16, 32], "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], input_ethosu_write: T.Buffer[[1, 16, 16, 8], "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 49e1407f2c47..513e04dc2090 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -186,8 +186,8 @@ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]): - A = T.buffer_decl(128, dtype="float32", data=input_A.data) - C = T.buffer_decl(128, dtype="float32", data=input_C.data) + A = T.buffer_decl(256, dtype="float32", data=input_A.data) + C = T.buffer_decl(256, dtype="float32", data=input_C.data) for i0 in T.serial(0, 4): B_new_data = T.allocate([68], "float32", scope="global") B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data) diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index f2e6f0612761..fe48aa7d8fd4 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -629,9 +629,11 @@ def test_condition_mutually_exclusive(): def test_loop_partition_unroll_hint(): @T.prim_func - def main(A: T.Buffer[[1, 3, 224, 224], "int8"], B: T.Buffer[[1, 224, 7, 16], "int8"]) -> None: - A = T.buffer_decl(150528, "int8", data=A.data) - B = T.buffer_decl(25088, "int8", data=B.data) + def main( + A_arg: T.Buffer[(1, 3, 224, 224), "int8"], B_arg: T.Buffer[(1, 224, 7, 16), "int8"] + ) -> None: + A = T.buffer_decl(150528, "int8", data=A_arg.data) + B = 
T.buffer_decl(25088, "int8", data=B_arg.data) for ax0 in T.serial( 112, annotations={"pragma_loop_partition_hint": True}, @@ -642,10 +644,10 @@ def main(A: T.Buffer[[1, 3, 224, 224], "int8"], B: T.Buffer[[1, 224, 7, 16], "in @T.prim_func def partitioned_main( - A: T.Buffer[[1, 3, 224, 224], "int8"], B: T.Buffer[[1, 224, 7, 16], "int8"] + A_arg: T.Buffer[(1, 3, 224, 224), "int8"], B_arg: T.Buffer[(1, 224, 7, 16), "int8"] ) -> None: - A = T.buffer_decl(150528, dtype="int8", data=A.data) - B = T.buffer_decl(25088, dtype="int8", data=B.data) + A = T.buffer_decl(150528, dtype="int8", data=A_arg.data) + B = T.buffer_decl(25088, dtype="int8", data=B_arg.data) # body for ax1, ax2, ax3 in T.grid(224, 7, 16): if 3 <= ax2 and ax3 < 3: @@ -691,8 +693,6 @@ def before(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None: @T.prim_func def after(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None: - T.preflattened_buffer(A, [160], dtype="int32", data=A.data) - T.preflattened_buffer(B, [160], dtype="int32", data=B.data) for i in T.serial(10, annotations={"key": "value"}): B[i] = A[i] + 1 for i in T.serial(140, annotations={"key": "value"}): @@ -740,10 +740,6 @@ def after( placeholder_2: T.Buffer[25088, "int8"], T_concat: T.Buffer[100352, "int8"], ) -> None: - T.preflattened_buffer(placeholder, [50176], dtype="int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, [25088], dtype="int8", data=placeholder_1.data) - T.preflattened_buffer(placeholder_2, [25088], dtype="int8", data=placeholder_2.data) - T.preflattened_buffer(T_concat, [100352], dtype="int8", data=T_concat.data) for _ in T.serial(1, annotations={"preserve_unit_loop": True}): for i1, i2, i3 in T.grid(64, 28, 28): T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 18607ca1a005..89deadde4de0 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -98,10 +98,10 @@ def ir(A, B): @tvm.testing.requires_cuda def test_sync_read_thread_id_independent_location(): @T.prim_func - def func(p0: T.Buffer[2, "float32"], p1: T.Buffer[2, "float32"]) -> None: + def func(p0_arg: T.Buffer[(1, 2, 1, 1), "float32"], p1: T.Buffer[2, "float32"]) -> None: threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") - T.preflattened_buffer(p0, [1, 2, 1, 1], dtype="float32", data=p0.data) + p0 = T.buffer_decl([2], dtype="float32", data=p0_arg.data) T.launch_thread(blockIdx_x, 8) result_local = T.alloc_buffer([1], dtype="float32", scope="local") temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared") diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index d586426931b9..d1f86814e7d6 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -561,9 +561,6 @@ def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr[T.uint8]) -> None: global_workspace_1_buffer_var = T.match_buffer( global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 ) - T.preflattened_buffer( - global_workspace_1_buffer_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 - ) dense_let = T.buffer_decl([10], "int32") 
with T.let(dense_let.data, T.address_of(global_workspace_1_buffer_var[0], dtype="handle")): T.evaluate( diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index fdf811af6522..f8ace3f22db8 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -41,7 +41,6 @@ def test_ir_builder_tir_primfunc_base(): body=tir.Evaluate(0), ret_type=None, buffer_map=None, - preflattened_buffer_map=None, attrs=None, ) From ff9fb58f705a25e02c142a35728fa6115bb357c5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Nov 2022 10:17:08 -0600 Subject: [PATCH 17/25] Directly write vector function instead of relying on tvm.lower --- .../unittest/test_arith_domain_touched.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 3641f06ab8a2..81bda403274f 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -30,18 +30,6 @@ def scalar_func(a: T.handle, b: T.handle): A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1] -@T.prim_func -def vector_func(a: T.handle, b: T.handle): - n = T.var("int32") - m = 128 - A = T.match_buffer(a, (n, m)) - B = T.match_buffer(b, (n, m)) - - for i in T.serial(n): - for j in T.vectorized(m): - A[i, j] = A[i, j] + B[i, j] - - def test_domain_touched(): func = scalar_func a, b = [func.buffer_map[var] for var in func.params] @@ -81,7 +69,17 @@ def test_domain_touched(): def test_domain_touched_vector(): - func = tvm.lower(vector_func)["main"] + m = tvm.runtime.convert(128) + + @T.prim_func + def func(a: T.handle, b: T.handle): + n = T.var("int32") + A = T.match_buffer(a, (n * m,)) + B = T.match_buffer(b, (n * m,)) + + for i in T.serial(n): + A[i * m : (i + 1) * m] = A[i * m : (i + 1) * m] + B[i * m : (i + 1) * m] + a, b = [func.buffer_map[var] for var in func.params] assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128 From cb15f962577e6162d9600a733923d6f8e064bb4f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Nov 2022 11:26:22 -0600 Subject: [PATCH 18/25] Updates to ethos-u constant encoding to avoid breakage --- .../backend/contrib/ethosu/tir/passes.py | 40 +++++++++++-------- .../test_ethosu/test_encode_constants.py | 8 ++-- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 598f6b76c8f8..083a3a1433b9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -326,7 +326,7 @@ def EncodeConstants(const_dict): """ new_const_dict = {} - def collect_encoding_definitions(stmt, old_buffer_to_const): + def collect_encoding_definitions(stmt, old_buffer_var_to_const): # Map from copy destination to copy source. 
copy_map = {} # List of buffer copies that occurred @@ -361,6 +361,7 @@ def _declare_constant_buffer(old_buffer, encoded_constants, split_idx): dtype=str(encoded_constants.dtype), name=old_buffer.name + "_encoded", scope=old_buffer.scope(), + data=old_buffer.data, ) constant_buffer_replacements.append( @@ -375,7 +376,7 @@ def _declare_constant_buffer(old_buffer, encoded_constants, split_idx): def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func): """Encode the weights or align the bias either for one or two cores, depending on the variant.""" - constant = old_buffer_to_const[buffer1] + constant = old_buffer_var_to_const[buffer1.data] # If we have just one core, encode the whole constant if buffer2 is None: @@ -470,7 +471,12 @@ def _visit(stmt): } def transform_stmt( - stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const, new_buffer_to_split_idx + stmt, + buf_remap, + var_remap, + pointer_to_buffer, + new_buffer_var_to_const, + new_buffer_to_split_idx, ): def _visit_rewrite(stmt): if isinstance(stmt, tvm.tir.Call): @@ -484,7 +490,7 @@ def _visit_rewrite(stmt): # encoded buffer, the current should be a length. if ( isinstance(prev_arg, tvm.tir.BufferLoad) - and prev_arg.buffer in new_buffer_to_const + and prev_arg.buffer.data in new_buffer_var_to_const ): buffer_size = np.prod(list(prev_arg.buffer.shape)) arg = buffer_size @@ -556,25 +562,25 @@ def _visit_rewrite(stmt): def _ftransform(f, mod, ctx): # Step 0: Unpack the constant dictionary in terms of the # functions buffers. - old_buffer_to_const = {} + old_buffer_var_to_const = {} for i, param in enumerate(f.params): if i in const_dict: - old_buffer_to_const[f.buffer_map[param]] = const_dict[i] + old_buffer_var_to_const[f.buffer_map[param].data] = const_dict[i] # Step 1: Collect information on the buffers that will be # replaced by encodings. - buffer_information = collect_encoding_definitions(f.body, old_buffer_to_const) + buffer_information = collect_encoding_definitions(f.body, old_buffer_var_to_const) # Step 2: Generate variable/buffer remaps, based on the # collected information. 
buf_remap = {} - new_buffer_to_const = {} + new_buffer_var_to_const = {} new_buffer_to_split_idx = {} # Any encoded buffers must be replaced for info in buffer_information["constant_buffer_replacements"]: buf_remap[info["old_buffer"]] = info["new_buffer"] - new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"] + new_buffer_var_to_const[info["new_buffer"].data] = info["encoded_constants"] if info["split_idx"]: new_buffer_to_split_idx[info["new_buffer"]] = info["split_idx"] @@ -596,8 +602,10 @@ def _ftransform(f, mod, ctx): scope=copy_dest.scope(), ) buf_remap[copy_dest] = new_dest - if copy_source in new_buffer_to_const: - new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source] + if copy_source.data in new_buffer_var_to_const: + new_buffer_var_to_const[new_dest.data] = new_buffer_var_to_const[ + copy_source.data + ] if copy_source in new_buffer_to_split_idx: new_buffer_to_split_idx[new_dest] = new_buffer_to_split_idx[copy_source] @@ -614,7 +622,7 @@ def _ftransform(f, mod, ctx): buf_remap, var_remap, pointer_to_buffer, - new_buffer_to_const, + new_buffer_var_to_const, new_buffer_to_split_idx, ) @@ -625,10 +633,10 @@ def _ftransform(f, mod, ctx): if buffer in buf_remap: buffer = buf_remap[buffer] - if buffer in new_buffer_to_const: - new_const_dict[i] = new_buffer_to_const[buffer].flatten() - elif buffer in old_buffer_to_const: - new_const_dict[i] = old_buffer_to_const[buffer].flatten() + if buffer.data in new_buffer_var_to_const: + new_const_dict[i] = new_buffer_var_to_const[buffer.data].flatten() + elif buffer.data in old_buffer_var_to_const: + new_const_dict[i] = old_buffer_var_to_const[buffer.data].flatten() new_buffer_map[param] = buffer diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 9f6e9c68f95f..603a6f22d361 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -152,12 +152,12 @@ def _get_func(): @tvm.script.ir_module class RereadWeightsU55: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([384], "uint8") - placeholder = T.buffer_decl([8192], "int8", data=placeholder.data) - ethosu_write = T.buffer_decl([2048], "int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data) # body p1_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) p1 = T.buffer_decl([384], "uint8", data=p1_data) @@ -361,7 +361,7 @@ def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], buffer_encoded: T buffer9 = T.buffer_decl([592], "uint8") buffer10 = T.buffer_decl([160], "uint8") buffer11 = T.buffer_decl([2048], "int8") - placeholder = T.buffer_decl([2048], "int8", data=placeholder.data) + placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data) # body p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) p1 = T.buffer_decl([112], "uint8", data=p1_data) From 9d7853f5e105790d28417baf9a2ee09b135884af Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Nov 2022 
16:39:34 -0600 Subject: [PATCH 19/25] A few more ethos-u updates There were a couple of spots that needed to be aware of flattened buffers being represented as buffer aliases. --- .../backend/contrib/ethosu/tir/passes.py | 33 +++++++++++++++++-- .../test_ethosu/test_remove_concatenates.py | 10 +++--- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 083a3a1433b9..e15d126dd969 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -361,7 +361,6 @@ def _declare_constant_buffer(old_buffer, encoded_constants, split_idx): dtype=str(encoded_constants.dtype), name=old_buffer.name + "_encoded", scope=old_buffer.scope(), - data=old_buffer.data, ) constant_buffer_replacements.append( @@ -559,7 +558,25 @@ def _visit_rewrite(stmt): ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"], ) + def _collect_parameter_buffer_aliases(prim_func): + buffer_vars = {} + for param in prim_func.params: + if param in prim_func.buffer_map: + buf = prim_func.buffer_map[param] + buffer_vars[buf.data] = {buf} + + def visit(node): + if isinstance(node, (tvm.tir.BufferStore, tvm.tir.BufferLoad, tvm.tir.DeclBuffer)): + buf = node.buffer + if buf.data in buffer_vars: + buffer_vars[buf.data].add(buf) + + tvm.tir.stmt_functor.post_order_visit(prim_func.body, visit) + return buffer_vars + def _ftransform(f, mod, ctx): + param_buffer_var_usage = _collect_parameter_buffer_aliases(f) + # Step 0: Unpack the constant dictionary in terms of the # functions buffers. old_buffer_var_to_const = {} @@ -577,9 +594,19 @@ def _ftransform(f, mod, ctx): new_buffer_var_to_const = {} new_buffer_to_split_idx = {} + def define_remap(old_buf, new_buf): + try: + old_buffers = param_buffer_var_usage[old_buf.data] + except KeyError: + old_buffers = [old_buf] + + for old_buffer in old_buffers: + buf_remap[old_buffer] = new_buf + # Any encoded buffers must be replaced for info in buffer_information["constant_buffer_replacements"]: - buf_remap[info["old_buffer"]] = info["new_buffer"] + define_remap(info["old_buffer"], info["new_buffer"]) + new_buffer_var_to_const[info["new_buffer"].data] = info["encoded_constants"] if info["split_idx"]: @@ -601,7 +628,7 @@ def _ftransform(f, mod, ctx): name=copy_dest.name, scope=copy_dest.scope(), ) - buf_remap[copy_dest] = new_dest + define_remap(copy_dest, new_dest) if copy_source.data in new_buffer_var_to_const: new_buffer_var_to_const[new_dest.data] = new_buffer_var_to_const[ copy_source.data diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index 1081df16c1f2..379a35b1b4a4 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -30,9 +30,14 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(input_placeholder: T.Buffer[(1, 8, 12, 16), "int8"], input_placeholder_1: T.Buffer[(1, 8, 10, 16), "int8"], input_T_concat: T.Buffer[(1, 8, 32, 16), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1,8,12,16), "int8"], input_placeholder_1: T.Buffer[(1,8,10,16), "int8"], input_T_concat: T.Buffer[(1,8,32,16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + + placeholder = T.buffer_decl(1536, dtype="int8", data=input_placeholder.data) + 
placeholder_1 = T.buffer_decl(1280, dtype="int8", data=input_placeholder_1.data) + T_concat = T.buffer_decl(4096, dtype="int8", data=input_T_concat.data) + buffer = T.buffer_decl([2992], "uint8") buffer_1 = T.buffer_decl([160], "uint8") buffer_2 = T.buffer_decl([2992], "uint8") @@ -41,9 +46,6 @@ def main(input_placeholder: T.Buffer[(1, 8, 12, 16), "int8"], input_placeholder_ buffer_5 = T.buffer_decl([160], "uint8") buffer_6 = T.buffer_decl([2992], "uint8") buffer_7 = T.buffer_decl([160], "uint8") - placeholder = T.buffer_decl([1536], "int8", data=input_placeholder.data) - placeholder_1 = T.buffer_decl([1280], "int8", data=input_placeholder_1.data) - T_concat = T.buffer_decl([1536], "int8", data=input_T_concat.data) # body T_concat_1_data = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) T_concat_1 = T.buffer_decl([2816], "int8", data=T_concat_1_data) From bd4b2dc403cc3a55f82bf40027b521b653ab85f9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Nov 2022 08:22:38 -0600 Subject: [PATCH 20/25] Updates following latest merge --- tests/python/contrib/test_ethosu/test_encode_constants.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 861ac47611fe..61128da71c37 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -352,7 +352,7 @@ def _get_func(): @tvm.script.ir_module class MixedReadU55: @T.prim_func - def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([112], "uint8") @@ -362,6 +362,7 @@ def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], ethosu_write: T.Buffer[(2048 buffer9 = T.buffer_decl([592], "uint8") buffer10 = T.buffer_decl([160], "uint8") ifm = T.buffer_decl([8192], "int8", data=input_ifm.data) + ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data) # body p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) p1 = T.buffer_decl([112], "uint8", data=p1_data) @@ -384,11 +385,12 @@ def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], ethosu_write: T.Buffer[(2048 @tvm.script.ir_module class MixedReadU65: @T.prim_func - def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition ifm = T.buffer_decl([8192], dtype="int8", data=input_ifm.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) buffer1 = T.buffer_decl([128], dtype="uint8") buffer2 = T.buffer_decl([128], dtype="uint8") buffer3 = T.buffer_decl([128], dtype="uint8") From d201e62a1a9b8bd7282b97f1c02c65c2ef74b407 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Nov 2022 08:24:14 -0600 Subject: [PATCH 21/25] Fixed updates in TVMScript for test_replace_conv2d --- .../python/contrib/test_ethosu/test_replace_conv2d.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index c48cf946fadc..46c6976567c8 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -388,15 +388,15 @@ def main(input_placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write @tvm.script.ir_module class Conv2dDoubleCascade2: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: + def main(input_placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([80], "uint8") buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([1312], "uint8") buffer_3 = T.buffer_decl([2608], "uint8") - placeholder_5 = T.buffer_decl([192], 'int8', data=placeholder_5.data) - ethosu_write_1 = T.buffer_decl([512], 'int8', data=ethosu_write_1.data) + placeholder_5 = T.buffer_decl([192], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.buffer_decl([512], 'int8', data=input_ethosu_write_1.data) # body ethosu_write_2_data = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) ethosu_write_2 = T.buffer_decl([1536], "int8", data=ethosu_write_2_data) @@ -464,8 +464,8 @@ def main(input_placeholder: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write: buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([304], "uint8") buffer_3 = T.buffer_decl([80], "uint8") - placeholder = T.buffer_decl([192], 'int8', data=placeholder.data) - ethosu_write = T.buffer_decl([8192], 'int8', data=ethosu_write.data) + placeholder = T.buffer_decl([192], 'int8', data=input_placeholder.data) + ethosu_write = T.buffer_decl([8192], 'int8', data=input_ethosu_write.data) # body ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) ethosu_write_1 = T.buffer_decl([4096], "int8", data=ethosu_write_1_data) From c37ec25e205ba687819ea6ef99765b3b0518a00f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Nov 2022 08:26:13 -0600 Subject: [PATCH 22/25] Fixing breakage in test_hoist_allocates.py --- .../test_ethosu/test_hoist_allocates.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_hoist_allocates.py b/tests/python/contrib/test_ethosu/test_hoist_allocates.py index 5bab415a4c0d..1508aa441c3b 100644 --- a/tests/python/contrib/test_ethosu/test_hoist_allocates.py +++ b/tests/python/contrib/test_ethosu/test_hoist_allocates.py @@ -106,15 +106,15 @@ def test_double_convolution(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(1, 27, 42, 3), "int8"], placeholder_encoded: T.Buffer[(3, 3, 2, 3), "uint8"], placeholder_encoded_1: T.Buffer[(3, 10), "uint8"], placeholder_encoded_2: T.Buffer[(3, 3, 2, 3), "uint8"], placeholder_encoded_3: T.Buffer[(3, 10), "uint8"], ethosu_write: T.Buffer[(1, 27, 42, 3), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 27, 42, 3), "int8"], input_placeholder_encoded: T.Buffer[(3, 3, 2, 3), "uint8"], input_placeholder_encoded_1: T.Buffer[(3, 10), "uint8"], input_placeholder_encoded_2: T.Buffer[(3, 3, 2, 3), "uint8"], input_placeholder_encoded_3: T.Buffer[(3, 10), "uint8"], input_ethosu_write: T.Buffer[(1, 27, 42, 3), "int8"]) -> None: # 
function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.buffer_decl([3402], dtype="int8", data=placeholder.data) - placeholder_encoded = T.buffer_decl([128], dtype="int8", data=placeholder_encoded.data) - placeholder_encoded_1 = T.buffer_decl([32], dtype="uint8", data=placeholder_encoded_1.data) - placeholder_encoded_2 = T.buffer_decl([128], dtype="int8", data=placeholder_encoded_2.data) - placeholder_encoded_3 = T.buffer_decl([32], dtype="uint8", data=placeholder_encoded_3.data) - ethosu_write = T.buffer_decl([3402], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([3402], dtype="int8", data=input_placeholder.data) + placeholder_encoded = T.buffer_decl([128], dtype="int8", data=input_placeholder_encoded.data) + placeholder_encoded_1 = T.buffer_decl([32], dtype="uint8", data=input_placeholder_encoded_1.data) + placeholder_encoded_2 = T.buffer_decl([128], dtype="int8", data=input_placeholder_encoded_2.data) + placeholder_encoded_3 = T.buffer_decl([32], dtype="uint8", data=input_placeholder_encoded_3.data) + ethosu_write = T.buffer_decl([3402], dtype="int8", data=input_ethosu_write.data) # body placeholder_global_data = T.allocate([128], "uint8", "global") placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data) @@ -150,10 +150,10 @@ def test_identities(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(1, 2, 3, 4), "int8"], T_concat: T.Buffer[(24,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 2, 3, 4), "int8"], T_concat: T.Buffer[(24,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.buffer_decl([24], dtype="int8", data=placeholder.data) + placeholder = T.buffer_decl([24], dtype="int8", data=input_placeholder.data) # body ethosu_write_data = T.allocate([12], "int8", "global") ethosu_write = T.buffer_decl([12], "int8", data=ethosu_write_data) @@ -187,11 +187,11 @@ def test_outer_seq_stmt(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.buffer_decl([8192], dtype="int8", data=placeholder.data) - ethosu_write = T.buffer_decl([2048], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body with T.allocate([128], 
"uint8", "global") as placeholder_global_data: placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data) @@ -237,11 +237,11 @@ def test_allocate_without_seq_stmt(): @tvm.script.ir_module class Module: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.buffer_decl([8192], dtype="int8", data=placeholder.data) - ethosu_write = T.buffer_decl([2048], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data) # body placeholder_global_data = T.allocate([128], "uint8", "global") placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data) From 299b7b390b71f7b0adb0836faf609e613f5c9457 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Nov 2022 08:36:09 -0600 Subject: [PATCH 23/25] Resolve breakage in test_merge_constants.py --- .../test_ethosu/test_merge_constants.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index bcfbb857a5db..ed1927b849d6 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -419,9 +419,12 @@ def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], buffer1: T.Buffer @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([464], "uint8", "global") p1 = T.buffer_decl([464], "uint8", data=p1_data) @@ -446,12 +449,12 @@ def test_arbitrary_argument_order(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(368,), "uint8"], 
buffer4: T.Buffer[(96,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([368], "uint8", "global") p1 = T.buffer_decl([368], "uint8", data=p1_data) @@ -473,9 +476,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([464], "uint8", "global") p1 = T.buffer_decl([464], "uint8", data=p1_data) @@ -509,12 +515,12 @@ def test_arbitrary_argument_order_const_split(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([368], "uint8", "global") p1 = T.buffer_decl([368], "uint8", data=p1_data) @@ -536,9 +542,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: # 
function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data) + ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data) # body p1_data = T.allocate([464], "uint8", "global") p1 = T.buffer_decl([464], "uint8", data=p1_data) @@ -572,12 +581,12 @@ def test_arbitrary_argument_order_const_split_mixed(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], input_ethosu_write: T.Buffer[(2,16,16,8), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + placeholder = T.buffer_decl(8192, dtype='int8', data=input_placeholder.data) + ethosu_write = T.buffer_decl(4096, dtype='int8', data=input_ethosu_write.data) # body p1_data = T.allocate([368], "uint8", "global") p1 = T.buffer_decl([368], "uint8", data=p1_data) @@ -599,9 +608,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"]) -> None: + def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(2,16,16,8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder = T.buffer_decl(8192, dtype='int8', data=input_placeholder.data) + ethosu_write = T.buffer_decl(4096, dtype='int8', data=input_ethosu_write.data) # body p1_data = T.allocate([464], "uint8", "global") p1 = T.buffer_decl([464], "uint8", data=p1_data) From de1d5fefa73145010b39e2d9c271a20849396898 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Nov 2022 10:32:53 -0600 Subject: [PATCH 24/25] Remove some debug code that broke PassContext Forwarding the instrumentation when evaluating FoldConstants helped in debugging, but shouldn't have made it in. --- src/relay/transforms/fold_constant.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 1ddb0e44eac1..9dec840be0a7 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -259,9 +259,7 @@ class ConstantFolder : public MixedModeMutator { // Use a fresh build context in case we are already in a build context. 
// needed for both execution and creation(due to JIT) - auto context = transform::PassContext::Create(); - context->instruments = transform::PassContext::Current()->instruments; - With<transform::PassContext> fresh_build_ctx(context); + With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create()); Map dict = (module_->attrs.defined()) ? Map(module_->attrs.CopyOnWrite()->dict) From 2292229b4459e0fa704036949aa6a7ace8339587 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Nov 2022 15:45:38 -0600 Subject: [PATCH 25/25] Updated TVMScript representation of Ramp The previous parser contextually interpreted buffer subscripts as Ramp nodes when they occurred as part of an expression. The new parser requires an explicit step size to generate a Ramp. --- tests/python/unittest/test_arith_domain_touched.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 81bda403274f..9f7eee096362 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -78,7 +78,7 @@ def func(a: T.handle, b: T.handle): B = T.match_buffer(b, (n * m,)) for i in T.serial(n): - A[i * m : (i + 1) * m] = A[i * m : (i + 1) * m] + B[i * m : (i + 1) * m] + A[i * m : (i + 1) * m : 1] = A[i * m : (i + 1) * m : 1] + B[i * m : (i + 1) * m : 1] a, b = [func.buffer_map[var] for var in func.params]
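
Note: the buffer-aliasing convention that the updated tests in this series converge on (the function parameter keeps its pre-flattened shape, while the body accesses a flat alias declared over the same data pointer, as mentioned in PATCH 19) can be shown with a standalone sketch. The sketch below is illustrative only and is not part of any patch in this series; the function and parameter names (alias_copy, a_arg, b_arg) are hypothetical, and it assumes only the TVMScript constructs already used in the tests above (T.prim_func, T.Buffer annotations, T.buffer_decl with data=, T.serial).

from tvm.script import tir as T

@T.prim_func
def alias_copy(a_arg: T.Buffer[(1, 4, 4), "int8"], b_arg: T.Buffer[(1, 4, 4), "int8"]) -> None:
    # Flat aliases of the unflattened parameters: each buffer_decl reuses the
    # parameter's backing allocation through data=, so no copy is made and the
    # unflattened shapes remain visible in the function signature.
    a = T.buffer_decl([16], dtype="int8", data=a_arg.data)
    b = T.buffer_decl([16], dtype="int8", data=b_arg.data)
    for i in T.serial(16):
        b[i] = a[i]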